|
@@ -362,7 +362,11 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
), f"rpc_forward should have number of points as number or None, got {points}"
|
|
|
|
|
|
hidden_states = await _rpc_forward(
|
|
|
- *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
|
|
|
+ *flat_inputs,
|
|
|
+ requested_backends=requested_backends,
|
|
|
+ prioritizer=self._prioritizer,
|
|
|
+ active_adapter=active_adapter,
|
|
|
+ points=points,
|
|
|
)
|
|
|
return runtime_pb2.ExpertResponse(
|
|
|
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
|
|
@@ -385,7 +389,11 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
), f"rpc_forward_stream should have number of points as number or None, got {points}"
|
|
|
|
|
|
hidden_states = await _rpc_forward(
|
|
|
- *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
|
|
|
+ *flat_inputs,
|
|
|
+ requested_backends=requested_backends,
|
|
|
+ prioritizer=self._prioritizer,
|
|
|
+ active_adapter=active_adapter,
|
|
|
+ points=points,
|
|
|
)
|
|
|
|
|
|
# Split the serialized_output for streaming and respond to client
|
|
@@ -432,7 +440,11 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
), f"rpc_backward should have number of points as number or None, got {points}"
|
|
|
|
|
|
grads = await _rpc_backward(
|
|
|
- *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
|
|
|
+ *flat_tensors,
|
|
|
+ requested_backends=requested_backends,
|
|
|
+ prioritizer=self._prioritizer,
|
|
|
+ active_adapter=active_adapter,
|
|
|
+ points=points,
|
|
|
)
|
|
|
|
|
|
return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
|
|
@@ -453,7 +465,11 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
), f"rpc_backward_stream should have number of points as number or None, got {points}"
|
|
|
|
|
|
grads = await _rpc_backward(
|
|
|
- *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points
|
|
|
+ *flat_tensors,
|
|
|
+ requested_backends=requested_backends,
|
|
|
+ prioritizer=self._prioritizer,
|
|
|
+ active_adapter=active_adapter,
|
|
|
+ points=points,
|
|
|
)
|
|
|
# Split the serialized_grad_inputs for streaming and respond
|
|
|
for tensor in self._serialize_grads(grads, requested_backends, metadata):
|
|
@@ -558,7 +574,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
async def _rpc_forward(
|
|
|
*flat_tensors: torch.Tensor,
|
|
|
requested_backends: Sequence[TransformerBackend],
|
|
|
- active_adapter: str = '',
|
|
|
+ active_adapter: str = "",
|
|
|
prioritizer: TaskPrioritizerBase,
|
|
|
points: int = 0,
|
|
|
) -> torch.Tensor:
|
|
@@ -590,7 +606,9 @@ async def _rpc_forward(
|
|
|
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
|
|
|
)
|
|
|
(hidden_states,) = await backend.forward_pool.submit_task(
|
|
|
- hidden_states, active_adapter, priority=priority,
|
|
|
+ hidden_states,
|
|
|
+ active_adapter,
|
|
|
+ priority=priority,
|
|
|
)
|
|
|
assert isinstance(hidden_states, torch.Tensor)
|
|
|
assert (
|
|
@@ -603,7 +621,7 @@ async def _rpc_forward(
|
|
|
async def _rpc_backward(
|
|
|
*flat_tensors: torch.Tensor,
|
|
|
requested_backends: Sequence[TransformerBackend],
|
|
|
- active_adapter: str = '',
|
|
|
+ active_adapter: str = "",
|
|
|
prioritizer: TaskPrioritizerBase,
|
|
|
points: int = 0,
|
|
|
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
|