@@ -202,9 +202,6 @@ class TransformerConnectionHandler(ConnectionHandler):
         hidden_states = await _rpc_forward(
             *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
         )
-        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
-
-        # Serialize output and respond to client
         return runtime_pb2.ExpertResponse(
             tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
         )
@@ -227,16 +224,15 @@ class TransformerConnectionHandler(ConnectionHandler):
         hidden_states = await _rpc_forward(
             *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
         )
-        serialized_outputs = self._serialize_outputs(hidden_states, requested_backends, metadata)

         # Split the serialized_output for streaming and respond to client
-        for tensor in serialized_outputs:
+        for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
             for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                 yield runtime_pb2.ExpertResponse(tensors=[part])

     def _serialize_outputs(
         self,
-        hidden_states: Sequence[torch.Tensor],
+        hidden_states: torch.Tensor,
         requested_backends: Sequence[TransformerBackend],
         metadata: Dict[str, Any],
     ) -> Sequence[runtime_pb2.Tensor]:
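
For context (not part of this diff): after this change, rpc_forward_stream yields parts of exactly one serialized tensor, since _serialize_outputs now returns a single-element list. A rough client-side sketch of reassembling that stream, assuming hivemind's combine_from_streaming and deserialize_torch_tensor helpers; the function name and the response iteration are illustrative only.

    from hivemind.compression import deserialize_torch_tensor
    from hivemind.utils.streaming import combine_from_streaming

    async def collect_streamed_output(response_stream):
        # Each ExpertResponse above carries one chunk of the single serialized
        # hidden_states tensor; gather the chunks, stitch them back together,
        # then deserialize into a torch.Tensor.
        parts = [response.tensors[0] async for response in response_stream]
        return deserialize_torch_tensor(combine_from_streaming(parts))
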
@@ -248,15 +244,13 @@ class TransformerConnectionHandler(ConnectionHandler):
             assert isinstance(metadata["output_compressions"], (list, tuple)), "output_compression must be a tuple/list"
             output_compressions = tuple(metadata["output_compressions"])
             assert all(isinstance(c, int) for c in output_compressions), "output_compression must contain integers"
-            assert len(output_compressions) == len(
-                hidden_states
-            ), f"output_compression should have {len(hidden_states)} elements"
+            assert len(output_compressions) == 1, f"output_compression tuple should have 1 element"
         else:
             output_compressions = tuple(tensor.compression for tensor in outputs_schema)

         return [
             serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
-            for result, proto, compression in zip(hidden_states, outputs_schema, output_compressions)
+            for result, proto, compression in zip([hidden_states], outputs_schema, output_compressions)
         ]

     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
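
For context (not part of this diff): with a single output tensor, a client-supplied output_compressions entry in the request metadata is now expected to hold exactly one compression code. A hedged sketch of building such metadata, assuming hivemind's MSGPackSerializer and the CompressionType enum from runtime_pb2; the surrounding request construction is omitted.

    from hivemind.proto import runtime_pb2
    from hivemind.utils.serializer import MSGPackSerializer

    # Ask the server to return hidden_states compressed to FP16; the assert
    # added above rejects anything other than a 1-element tuple/list.
    metadata = MSGPackSerializer.dumps(
        {"output_compressions": (runtime_pb2.CompressionType.FLOAT16,)}
    )
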
@@ -296,9 +290,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         grads = await _rpc_backward(
             *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
         )
-        serialized_grad_inputs = self._serialize_grads(grads, requested_backends, metadata)
         # Split the serialized_grad_inputs for streaming and respond
-        for tensor in serialized_grad_inputs:
+        for tensor in self._serialize_grads(grads, requested_backends, metadata):
             for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                 yield runtime_pb2.ExpertResponse(tensors=[part])

@@ -308,7 +301,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends: Sequence[TransformerBackend],
         metadata: Dict[str, Any],
     ) -> Sequence[runtime_pb2.Tensor]:
-        """Serialize gradients w.r.t. inputs using either backward schema or custom user-specified schema"""
+        """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
         flat_grads_schema = tuple(
@@ -328,7 +321,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             for result, proto, compression in zip(grads, flat_grads_schema, output_compressions)
         ]

-    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
+    def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:
         """Check that the first request to rpc_inference is valid"""
         uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids:
|