@@ -592,7 +592,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))

     @staticmethod
-    async def _wrap_input_stream(stream):
+    async def _read_until_eos(stream):
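+        # Re-yield requests from the underlying stream so one bidirectional stream can be
+        # consumed in stages (forward inputs now, grad outputs later). NOTE: despite the
+        # name, no _EOS marker is checked here yet; this relays until the caller stops.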
         while True:
             expert_request = await anext(stream)
             yield expert_request
@@ -605,10 +605,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         async with timeout(self.request_timeout):
-            wrapped_requests = self._wrap_input_stream(requests)

             # Parse requests and prepare backends
-            uid_str, flat_inputs, metadata = await self._gather_inputs(wrapped_requests, context)
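+            # Read only the forward inputs for now; the same stream is read again
+            # below to receive the gradient outputs for the backward pass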
+            uid_str, flat_inputs, metadata = await self._gather_inputs(self._read_until_eos(requests), context)
             requested_uids = self._check_uids(uid_str)
             self._log_request("rpc_forward_stream", requested_uids, context)

@@ -620,6 +619,8 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_forward_stream should have number of points as number or None, got {points}"

+            print(f"{requested_backends=}, {active_adapter=}, {points=}, {args_structure=}")
+
             hidden_states = await run_rpc_forward(
                 *flat_inputs,
                 requested_backends=requested_backends,
@@ -632,4 +633,36 @@ class TransformerConnectionHandler(ConnectionHandler):
             for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
                 for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                     print("EOS")
-                    yield runtime_pb2.ExpertResponse(tensors=[part], metadata=MSGPackSerializer.dumps({"EOS": True}))
+                    yield runtime_pb2.ExpertResponse(tensors=[part], metadata=MSGPackSerializer.dumps({"_EOS": True}))
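+                    # NOTE: the _EOS flag is attached to every part here, not only the last
+                    # one; presumably it signals that the forward results are finished
+                    # before the handler starts waiting for grad outputs below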
+
+            # Gather the gradient outputs for the backward pass from the same request stream
+            new_uid_str, flat_extra_inputs, extra_metadata = await self._gather_inputs(self._read_until_eos(requests), context)
+            backward_args_structure = extra_metadata.get("args_structure")
+            assert len(flat_extra_inputs) == 1, "expected a single grad_outputs tensor"
+            assert new_uid_str == uid_str, "backward must target the same uids as forward"
+            # TODO: decide how to use extra_metadata for pushing once it reaches this point
+            (grad_outputs,) = flat_extra_inputs
+
+            # Temporary debug output (to be removed before merging)
+            print("FLAT INPUTS", flat_inputs)
+            print("GRAD OUTPUTS", grad_outputs)
+            print("BACKWARD ARGS STRUCTURE", backward_args_structure)
+
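+            # Backward through the requested backends: forward inputs first, then
+            # grad outputs, then any remaining input tensors (e.g. prompts)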
+            grads = await run_rpc_backward(
+                flat_inputs[0],
+                grad_outputs,
+                *flat_inputs[1:],
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
+                args_structure=backward_args_structure,
+            )
+
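+            # NOTE: grads are serialized with the metadata of the forward request;
+            # extra_metadata is only consulted for args_structure above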
+            # Split the serialized_grad_inputs for streaming and respond
+            for tensor in self._serialize_grads(grads, requested_backends, metadata):
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    print("SENDING GRADS:", part)
+                    yield runtime_pb2.ExpertResponse(tensors=[part])