dbaranchuk 3 gadi atpakaļ
vecāks
revīzija
7d5348b10b
1 mainītis faili ar 2 papildinājumiem un 1 dzēšanām
  1. 2 1
      src/server/handler.py

+ 2 - 1
src/server/handler.py

@@ -96,7 +96,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         return runtime_pb2.ExpertResponse(
             tensors=[
-                serialize_torch_tensor(result.type(proto.dtype), proto.compression, allow_inplace=True)
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                 for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )
@@ -172,6 +172,7 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+
         uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
         inputs, grads = inputs_and_grads
         requested_uids = self._check_header_str(uids_header)