Denis Mazur 1 year ago
parent
commit
26d4cd855d
2 changed files with 111 additions and 40 deletions
  1. 74 36
      examples/workbench_call_rpc_directly.ipynb
  2. 37 4
      src/petals/server/handler.py

File diff suppressed because it is too large
+ 74 - 36
examples/workbench_call_rpc_directly.ipynb


+ 37 - 4
src/petals/server/handler.py

@@ -592,7 +592,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
 
     @staticmethod
-    async def _wrap_input_stream(stream):
+    async def _read_until_eos(stream):
         while True:
             expert_request = await anext(stream)
             yield expert_request
@@ -605,10 +605,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         async with timeout(self.request_timeout):
-            wrapped_requests = self._wrap_input_stream(requests)
 
             # Parse requests and prepare backends
-            uid_str, flat_inputs, metadata = await self._gather_inputs(wrapped_requests, context)
+            uid_str, flat_inputs, metadata = await self._gather_inputs(self._read_until_eos(requests), context)
             requested_uids = self._check_uids(uid_str)
             self._log_request("rpc_forward_stream", requested_uids, context)
 
@@ -620,6 +619,8 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
+            print(f"{requested_backends=}, {active_adapter=}, {points=}, {args_structure=}")
+
             hidden_states = await run_rpc_forward(
                 *flat_inputs,
                 requested_backends=requested_backends,
@@ -632,4 +633,36 @@ class TransformerConnectionHandler(ConnectionHandler):
             for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
                 for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                     print("EOS")
-                    yield runtime_pb2.ExpertResponse(tensors=[part], metadata=MSGPackSerializer.dumps({"EOS": True}))
+                    yield runtime_pb2.ExpertResponse(tensors=[part], metadata=MSGPackSerializer.dumps({"_EOS": True}))
+
+            
+            #### 
+            new_uid_str, flat_extra_inputs, extra_metadata = await self._gather_inputs(self._read_until_eos(requests), context)
+            backward_args_structure = extra_metadata.get("args_structure")
+            assert len(flat_extra_inputs) == 1
+            assert new_uid_str == uid_str
+            print("I solemnly swear to think about how to use extra_metadata for pushing when it comes to this")
+            grad_outputs, = flat_extra_inputs
+
+            print("HERE!")
+            
+            print("FLAT INPUTS", flat_inputs)
+            print("GRAD OUTPUTS", grad_outputs)
+            print(backward_args_structure)
+            
+            grads = await run_rpc_backward(
+                flat_inputs[0],
+                grad_outputs,
+                *flat_inputs[1:],
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
+                args_structure=backward_args_structure,
+            )
+
+            # Split the serialized_grad_inputs for streaming and respond
+            for tensor in self._serialize_grads(grads, requested_backends, metadata):
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    print("SENDING GRADS:", part)
+                    yield runtime_pb2.ExpertResponse(tensors=[part])

Some files were not shown because too many files changed in this diff