
option to serialize outputs with user-defined compression

justheuristic · 2 years ago
Parent commit 1958d1b5cf
1 file changed with 61 additions and 52 deletions

src/server/handler.py  (+61 −52)
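The diff below makes both the forward and backward handlers read an optional `output_compressions` entry from the request metadata and use it in place of the compression codes from the backend schema. The following client-side sketch shows how that metadata could be populated; it is an illustration, not part of this commit, and it assumes the metadata is a MessagePack-encoded dict carried in a `metadata` bytes field of `ExpertRequest` (as the handler's deserialized `metadata` dict suggests) and that compression codes come from hivemind's `runtime_pb2.CompressionType` enum. The module UID, tensor shape, and import paths are assumptions.

# Hypothetical client-side sketch (not part of this commit): ask the server to
# return the forward output compressed to FLOAT16 instead of the schema default.
import torch
from hivemind.compression import serialize_torch_tensor  # import path may vary across hivemind versions
from hivemind.proto import runtime_pb2
from hivemind.utils.serializer import MSGPackSerializer

hidden_states = torch.randn(1, 128, 4096)  # (batch, seq_len, hidden_size), illustrative shape

# one compression code per returned tensor; rpc_forward returns a single hidden_states tensor
metadata = MSGPackSerializer.dumps({"output_compressions": (runtime_pb2.CompressionType.FLOAT16,)})

request = runtime_pb2.ExpertRequest(
    uid="bloom.petals.0",  # illustrative module UID
    tensors=[serialize_torch_tensor(hidden_states, runtime_pb2.CompressionType.NONE)],
    metadata=metadata,  # assumes ExpertRequest exposes a metadata bytes field
)
# The handler checks that output_compressions is a list/tuple of ints with one entry
# per output tensor, then serializes each result with the requested compression code.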

@@ -1,6 +1,6 @@
 import asyncio
 import contextlib
-from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
 
 import torch
 from async_timeout import timeout
@@ -206,10 +206,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             # Serialize output and respond to client
             return runtime_pb2.ExpertResponse(
-                tensors=[
-                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                    for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-                ]
+                tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
             )
 
     async def rpc_forward_stream(
@@ -230,22 +227,37 @@ class TransformerConnectionHandler(ConnectionHandler):
             hidden_states = await _rpc_forward(
                 *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-            assert (
-                isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
-            ), "hidden_states must be a 3d tensor"
-
-            # Serialize the overall output
-            serialized_output = [
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-            ]
+            serialized_outputs = self._serialize_outputs(hidden_states, requested_backends, metadata)
 
             # Split the serialized_output for streaming and respond to client
-            output_split = [
-                part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-            ]
-            async for part in as_aiter(*output_split):
-                yield runtime_pb2.ExpertResponse(tensors=[part])
+            for tensor in serialized_outputs:
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    def _serialize_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        requested_backends: Sequence[TransformerBackend],
+        metadata: Dict[str, Any],
+    ) -> Sequence[runtime_pb2.Tensor]:
+        """Serialize forward outputs using either outputs_schema or user-specified compression"""
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
+        outputs_schema = tuple(nested_flatten(requested_backends[-1].outputs_schema))
+
+        if metadata.get("output_compressions") is not None:
+            assert isinstance(metadata["output_compressions"], (list, tuple)), "output_compressions must be a tuple/list"
+            output_compressions = tuple(metadata["output_compressions"])
+            assert all(isinstance(c, int) for c in output_compressions), "output_compressions must contain integers"
+            assert len(output_compressions) == len(
+                outputs_schema
+            ), f"output_compressions should have {len(outputs_schema)} elements"
+        else:
+            output_compressions = tuple(tensor.compression for tensor in outputs_schema)
+
+        return [
+            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
+            for result, proto, compression in zip((hidden_states,), outputs_schema, output_compressions)
+        ]
 
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         async with timeout(self.request_timeout):
@@ -265,21 +277,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
 
-            # Modify grad_inputs_schema to support grad_prompts
-            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
-
-            grad_inputs_schema_with_prompts = (
-                requested_backends[0].args_schema * len(grads),
-                requested_backends[0].kwargs_schema,
-            )  # TODO generalize
-
-            # Serialize the overall grad_input and respond
-            return runtime_pb2.ExpertResponse(
-                tensors=[
-                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                    for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-                ]
-            )
+            return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
 
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -298,26 +296,37 @@ class TransformerConnectionHandler(ConnectionHandler):
             grads = await _rpc_backward(
                 *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-
-            # Modify grad_inputs_schema to support grad_prompts
-            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
-            grad_inputs_schema_with_prompts = (
-                requested_backends[0].args_schema * len(grads),
-                requested_backends[0].kwargs_schema,
-            )  # TODO generalize
-
-            # Serialize the overall grad_inputs
-            serialized_grad_inputs = [
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-            ]
+            serialized_grad_inputs = self._serialize_grads(grads, requested_backends, metadata)
             # Split the serialized_grad_inputs for streaming and respond
-            output_split = [
-                part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-            ]
+            for tensor in serialized_grad_inputs:
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    yield runtime_pb2.ExpertResponse(tensors=[part])
 
-            async for part in as_aiter(*output_split):
-                yield runtime_pb2.ExpertResponse(tensors=[part])
+    def _serialize_grads(
+        self,
+        grads: Sequence[torch.Tensor],
+        requested_backends: Sequence[TransformerBackend],
+        metadata: Dict[str, Any],
+    ) -> Sequence[runtime_pb2.Tensor]:
+        """Serialize gradients w.r.t. inputs using either backward schema or custom user-specified schema"""
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+        flat_grads_schema = tuple(
+            nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))
+        )  # TODO generalize
+
+        if metadata.get("output_compressions") is not None:
+            assert isinstance(metadata["output_compressions"], (list, tuple)), "output_compressions must be a tuple/list"
+            output_compressions = tuple(metadata["output_compressions"])
+            assert all(isinstance(c, int) for c in output_compressions), "output_compressions must contain integers"
+            assert len(output_compressions) == len(grads), f"output_compressions should have {len(grads)} elements"
+        else:
+            output_compressions = tuple(tensor.compression for tensor in flat_grads_schema)
+
+        return [
+            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
+            for result, proto, compression in zip(grads, flat_grads_schema, output_compressions)
+        ]
 
     def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
         """Check that the first request to rpc_inference is valid"""