@@ -121,6 +121,7 @@ class TransformerConnectionHandler(ConnectionHandler):
hidden_states, prompts, hypo_ids = [
deserialize_torch_tensor(tensor) for tensor in request.tensors
]
+ initial_hidden_states = hidden_states.clone()

# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -169,14 +170,10 @@ class TransformerConnectionHandler(ConnectionHandler):
cache_metadata, hidden_states, hypo_ids, priority=priority
)

- # serialize and send last layer outputs
+ # serialize and send last layer outputs without the residual component
+ outputs = hidden_states - initial_hidden_states
yield runtime_pb2.ExpertResponse(
- tensors=[
- serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
- for result, proto in zip(
- (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
- )
- ]
+ tensors=self._serialize_outputs(outputs, requested_backends, metadata)
)

# prepare for next step
@@ -199,11 +196,11 @@ class TransformerConnectionHandler(ConnectionHandler):
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"

- hidden_states = await _rpc_forward(
+ outputs = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
return runtime_pb2.ExpertResponse(
- tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
+ tensors=self._serialize_outputs(outputs, requested_backends, metadata)
)

async def rpc_forward_stream(
@@ -221,12 +218,12 @@
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"

- hidden_states = await _rpc_forward(
+ outputs = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)

# Split the serialized_output for streaming and respond to client
- for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
+ for tensor in self._serialize_outputs(outputs, requested_backends, metadata):
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
yield runtime_pb2.ExpertResponse(tensors=[part])

@@ -379,12 +376,15 @@ async def _rpc_forward(
points: int = 0,
) -> torch.Tensor:
"""
- Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+ Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream.
+ A forward pass computes transformer hidden states after applying all requested_backends without the residual part.
+ In other words, it returns the last hidden states minus the first hidden states provided by the user.

:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
- :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+ :returns: hidden states after the last layer *without residuals*, [batch_size, seq_length, hid_size]
+ :note: this method returns (layerN(...layer1(inputs)...) - inputs) to reduce compression error
"""
hidden_states, prompts = flat_tensors
dtype = requested_backends[0].dtype
@@ -395,6 +395,7 @@ async def _rpc_forward(
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+ initial_hidden_states = hidden_states.clone()

# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
@@ -405,17 +406,14 @@
priority = prioritizer.prioritize(
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
)
- (hidden_states,) = await backend.forward_pool.submit_task(
- hidden_states,
- priority=priority,
- )
+ (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, priority=priority)
assert isinstance(hidden_states, torch.Tensor)
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

- # Serialize the overall output
- return hidden_states
+ # Return the difference between last hidden states and input hidden states (remove the residual component)
+ return torch.sub(hidden_states, initial_hidden_states, out=initial_hidden_states)


async def _rpc_backward(
@@ -424,10 +422,15 @@ async def _rpc_backward(
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
+ """
+ Backpropagate gradients through _rpc_forward; return gradients w.r.t. inputs and optional prompts without the residual component.
+ :note: like in _rpc_forward, this method returns (grad_inputs - grad_outputs) to reduce compression error
+ """
inputs, grad_outputs, prompts = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+ original_grad_outputs = grad_outputs.clone()

if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
@@ -469,4 +472,5 @@ async def _rpc_backward(
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
- return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape
+ grad_inputs = torch.sub(grad_outputs, original_grad_outputs, out=original_grad_outputs) # remove residuals
+ return [grad_inputs] if is_dummy(grad_prompts) else [grad_inputs, grad_prompts] # TODO un-duct-tape
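
A minimal sketch (illustration only, not part of this patch) of the round trip the forward change assumes: the server ships block_chain(x) - x, and the caller restores the full hidden states by adding back its own copy of the inputs. The toy blocks, shapes, and helper names are made up.

import torch

def server_forward(hidden_states, blocks):
    # mirrors _rpc_forward: run the chain of blocks, then return only the non-residual part
    # (the patch itself uses torch.sub(..., out=...) to reuse the cloned buffer)
    initial_hidden_states = hidden_states.clone()
    for block in blocks:
        hidden_states = block(hidden_states)
    return hidden_states - initial_hidden_states

def client_reconstruct(inputs, outputs_delta):
    # the caller already holds `inputs`, so adding the delta recovers the full hidden states
    return inputs + outputs_delta

blocks = [torch.nn.Linear(16, 16) for _ in range(3)]  # stand-ins for transformer blocks
x = torch.randn(2, 5, 16)  # [batch_size, seq_length, hid_size]
reference = x
for block in blocks:
    reference = block(reference)
assert torch.allclose(client_reconstruct(x, server_forward(x, blocks)), reference, atol=1e-6)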
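Why shipping the delta can reduce compression error (a toy illustration with an assumed max-abs uniform 8-bit quantizer, not the serializer actually used here): the quantization step scales with a tensor's dynamic range, and a per-block update is typically much smaller than the accumulated residual stream, so quantizing the delta and re-adding the locally held inputs loses less precision than quantizing the full output. The same argument applies to the gradient delta returned by _rpc_backward.

import torch

def quantize_roundtrip(t, bits=8):
    # naive symmetric uniform quantizer, for illustration only
    scale = t.abs().max() / (2 ** (bits - 1) - 1)
    return torch.round(t / scale) * scale

torch.manual_seed(0)
x = 10.0 * torch.randn(2, 5, 1024)     # residual stream: large dynamic range
delta = 0.1 * torch.randn(2, 5, 1024)  # per-block update: much smaller range
y = x + delta                          # full hidden states after the blocks

err_full = (quantize_roundtrip(y) - y).abs().mean()           # quantize the full output
err_delta = (x + quantize_roundtrip(delta) - y).abs().mean()  # quantize only the delta
print(f"mean abs error, full output: {err_full.item():.6f}")
print(f"mean abs error, delta only:  {err_delta.item():.6f}")  # roughly 100x smaller here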