Bläddra i källkod

fix data structure

justheuristic 2 år sedan
förälder
incheckning
440cb4b2a4
1 ändrad fil med 7 tillägg och 14 borttagningar
  1. 7 14
      src/server/handler.py

+ 7 - 14
src/server/handler.py

@@ -202,9 +202,6 @@ class TransformerConnectionHandler(ConnectionHandler):
             hidden_states = await _rpc_forward(
                 *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-            assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
-
-            # Serialize output and respond to client
             return runtime_pb2.ExpertResponse(
                 tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
             )
@@ -227,16 +224,15 @@ class TransformerConnectionHandler(ConnectionHandler):
             hidden_states = await _rpc_forward(
                 *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-            serialized_outputs = self._serialize_outputs(hidden_states, requested_backends, metadata)
 
             # Split the serialized_output for streaming and respond to client
-            for tensor in serialized_outputs:
+            for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
                 for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                     yield runtime_pb2.ExpertResponse(tensors=[part])
 
     def _serialize_outputs(
         self,
-        hidden_states: Sequence[torch.Tensor],
+        hidden_states: torch.Tensor,
         requested_backends: Sequence[TransformerBackend],
         metadata: Dict[str, Any],
     ) -> Sequence[runtime_pb2.Tensor]:
@@ -248,15 +244,13 @@ class TransformerConnectionHandler(ConnectionHandler):
             assert isinstance(metadata["output_compressions"], (list, tuple)), "output_compression must be a tuple/list"
             output_compressions = tuple(metadata["output_compressions"])
             assert all(isinstance(c, int) for c in output_compressions), "output_compression must contain integers"
-            assert len(output_compressions) == len(
-                hidden_states
-            ), f"output_compression should have {len(hidden_states)} elements"
+            assert len(output_compressions) == 1, f"output_compression tuple should have 1 element"
         else:
             output_compressions = tuple(tensor.compression for tensor in outputs_schema)
 
         return [
             serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
-            for result, proto, compression in zip(hidden_states, outputs_schema, output_compressions)
+            for result, proto, compression in zip([hidden_states], outputs_schema, output_compressions)
         ]
 
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
@@ -296,9 +290,8 @@ class TransformerConnectionHandler(ConnectionHandler):
             grads = await _rpc_backward(
                 *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-            serialized_grad_inputs = self._serialize_grads(grads, requested_backends, metadata)
             # Split the serialized_grad_inputs for streaming and respond
-            for tensor in serialized_grad_inputs:
+            for tensor in self._serialize_grads(grads, requested_backends, metadata):
                 for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                     yield runtime_pb2.ExpertResponse(tensors=[part])
 
@@ -308,7 +301,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends: Sequence[TransformerBackend],
         metadata: Dict[str, Any],
     ) -> Sequence[runtime_pb2.Tensor]:
-        """Serialize gradients w.r.t. inputs using either backward schema or custom user-specified schema"""
+        """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
         flat_grads_schema = tuple(
@@ -328,7 +321,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             for result, proto, compression in zip(grads, flat_grads_schema, output_compressions)
         ]
 
-    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
+    def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:
         """Check that the first request to rpc_inference is valid"""
         uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids: