@@ -1,6 +1,6 @@
 import asyncio
 import contextlib
-from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union

 import torch
 from async_timeout import timeout
@@ -202,14 +202,8 @@ class TransformerConnectionHandler(ConnectionHandler):
             hidden_states = await _rpc_forward(
                 *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-            assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
-
-            # Serialize output and respond to client
             return runtime_pb2.ExpertResponse(
-                tensors=[
-                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                    for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-                ]
+                tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
             )

     async def rpc_forward_stream(
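
For reference: the new `metadata` argument is the MessagePack-decoded dict the handler already builds from the request, and `_serialize_outputs` (introduced in the next hunk) reads an optional `output_compression` key from it. Below is a minimal client-side sketch of that contract; the `points` entry follows the existing metadata dict, while the actual request plumbing is assumed rather than shown in this patch.

# Hypothetical client-side sketch (not part of this patch): ask the server to return
# float16-compressed activations from rpc_forward. CompressionType values are plain ints,
# so they survive the MessagePack round-trip and the handler's isinstance(c, int) checks.
from hivemind.proto import runtime_pb2
from hivemind.utils.serializer import MSGPackSerializer

metadata_bytes = MSGPackSerializer.dumps(
    dict(
        points=0,  # priority points, already part of the metadata dict the handler expects
        output_compression=(runtime_pb2.CompressionType.FLOAT16,),  # exactly one entry for rpc_forward
    )
)
# _serialize_outputs() requires output_compression to be a list/tuple of ints of length 1
# for forward calls and falls back to the backend's outputs_schema compression otherwise.
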
@@ -230,22 +224,34 @@ class TransformerConnectionHandler(ConnectionHandler):
             hidden_states = await _rpc_forward(
                 *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-            assert (
-                isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
-            ), "hidden_states must be a 3d tensor"
-
-            # Serialize the overall output
-            serialized_output = [
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-            ]

             # Split the serialized_output for streaming and respond to client
-            output_split = [
-                part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-            ]
-            async for part in as_aiter(*output_split):
-                yield runtime_pb2.ExpertResponse(tensors=[part])
+            for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    def _serialize_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        requested_backends: Sequence[TransformerBackend],
+        metadata: Dict[str, Any],
+    ) -> Sequence[runtime_pb2.Tensor]:
+        """Serialize forward outputs using either outputs_schema or custom user-specified schema"""
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
+        outputs_schema = requested_backends[-1].outputs_schema
+
+        if metadata.get("output_compression") is not None:
+            assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
+            output_compression = tuple(metadata["output_compression"])
+            assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
+            assert len(output_compression) == 1, "output_compression tuple should have 1 element"
+        else:
+            output_compression = tuple(tensor.compression for tensor in outputs_schema)
+
+        return [
+            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
+            for result, proto, compression in zip([hidden_states], outputs_schema, output_compression)
+        ]

     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         async with timeout(self.request_timeout):
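
The streaming loop above emits one ExpertResponse per chunk produced by split_for_streaming. Below is a round-trip sketch of that chunking using hivemind's streaming helpers; the 1 MiB chunk size is an assumed stand-in for DEFAULT_MAX_MSG_SIZE, and the tensor shape is illustrative only.

# Round-trip sketch for the streaming split: the handler yields one part per message;
# a receiver can glue the parts back together with combine_from_streaming before deserializing.
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto import runtime_pb2
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

original = torch.randn(1, 4, 4096)
serialized = serialize_torch_tensor(original, runtime_pb2.CompressionType.NONE)
parts = list(split_for_streaming(serialized, 2 ** 20))  # what the handler iterates over, one part per message
restored = deserialize_torch_tensor(combine_from_streaming(parts))  # client-side reassembly
assert torch.allclose(restored, original)
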
@@ -265,21 +271,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )

-            # Modify grad_inputs_schema to support grad_prompts
-            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
-
-            grad_inputs_schema_with_prompts = (
-                requested_backends[0].args_schema * len(grads),
-                requested_backends[0].kwargs_schema,
-            )  # TODO generalize
-
-            # Serialize the overall grad_input and respond
-            return runtime_pb2.ExpertResponse(
-                tensors=[
-                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                    for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-                ]
-            )
+            return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))

     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
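
The same output_compression metadata applies to backward calls, but _serialize_grads (added in the next hunk) expects one entry per returned gradient: one without prompts, two when prompt gradients are returned. A hedged sketch mirroring the forward example, with the request plumbing again assumed:

# Hypothetical metadata for a backward call with prompts (names assumed): two gradients
# come back, so output_compression needs two entries to pass the length check below.
from hivemind.proto import runtime_pb2
from hivemind.utils.serializer import MSGPackSerializer

metadata_bytes = MSGPackSerializer.dumps(
    dict(points=0, output_compression=(runtime_pb2.CompressionType.FLOAT16,) * 2)
)
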
@@ -298,28 +290,38 @@ class TransformerConnectionHandler(ConnectionHandler):
             grads = await _rpc_backward(
                 *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-
-            # Modify grad_inputs_schema to support grad_prompts
-            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
-            grad_inputs_schema_with_prompts = (
-                requested_backends[0].args_schema * len(grads),
-                requested_backends[0].kwargs_schema,
-            )  # TODO generalize
-
-            # Serialize the overall grad_inputs
-            serialized_grad_inputs = [
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-            ]
             # Split the serialized_grad_inputs for streaming and respond
-            output_split = [
-                part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-            ]
+            for tensor in self._serialize_grads(grads, requested_backends, metadata):
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    yield runtime_pb2.ExpertResponse(tensors=[part])

-            async for part in as_aiter(*output_split):
-                yield runtime_pb2.ExpertResponse(tensors=[part])
-
-    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
+    def _serialize_grads(
+        self,
+        grads: Sequence[torch.Tensor],
+        requested_backends: Sequence[TransformerBackend],
+        metadata: Dict[str, Any],
+    ) -> Sequence[runtime_pb2.Tensor]:
+        """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+        flat_grads_schema = tuple(
+            nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))
+        )  # TODO generalize
+
+        if metadata.get("output_compression") is not None:
+            assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
+            output_compression = tuple(metadata["output_compression"])
+            assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
+            assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements"
+        else:
+            output_compression = tuple(tensor.compression for tensor in flat_grads_schema)
+
+        return [
+            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
+            for result, proto, compression in zip(grads, flat_grads_schema, output_compression)
+        ]
+
+    def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:
         """Check that the first request to rpc_inference is valid"""
         uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids:
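
A worked example of the flat_grads_schema construction above, with placeholder strings standing in for the real tensor descriptors: repeating args_schema once per gradient and flattening the (args, kwargs) pair yields a flat tuple that lines up element-wise with grads.

# Worked example of flat_grads_schema (placeholder values, not the real descriptors):
from hivemind.utils import nested_flatten

args_schema = ("hidden_states_descriptor",)     # one positional input, as the assert above requires
kwargs_schema = {}                              # assumed empty for this illustration
grads = ("grad_hidden_states", "grad_prompts")  # two gradients when prompts are used

flat_grads_schema = tuple(nested_flatten((args_schema * len(grads), kwargs_schema)))
assert flat_grads_schema == ("hidden_states_descriptor", "hidden_states_descriptor")
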
@@ -360,7 +362,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield handles

-    def _log_request(self, method: str, uids: List[ModuleUID], context: P2PContext) -> None:
+    def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None:
         friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
         friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
         friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids