@@ -298,7 +298,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends: Sequence[TransformerBackend],
         metadata: Dict[str, Any],
     ) -> Sequence[runtime_pb2.Tensor]:
-        """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
+        """
+        Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema
+        :note: this code expects grads to be gradients w.r.t. inputs without residuals (as returned by rpc_backward)
+        """
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
         flat_grads_schema = tuple(
@@ -472,5 +475,5 @@ async def _rpc_backward(
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
 
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
-    grad_inputs = torch.sub(grad_prompts, original_grad_outputs, out=original_grad_outputs)  # remove residuals
+    grad_inputs = torch.sub(grad_outputs, original_grad_outputs, out=original_grad_outputs)  # remove residuals
     return [grad_inputs] if is_dummy(grad_prompts) else [grad_inputs, grad_prompts]  # TODO un-duct-tape
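For context on the second hunk: `torch.sub(input, other, out=...)` computes `input - other` and writes the result into `out`, so the first argument has to be the gradient accumulated by the backward chain (`grad_outputs`), not `grad_prompts`. Below is a minimal sketch, not part of the patch, of what the corrected line does under that reading; the variable names mirror the diff, while the shapes and the way `grad_outputs` is produced are made-up stand-ins for the real backward pass in `_rpc_backward`.

```python
import torch

# Hypothetical shapes (batch, seq_len, hidden) -- illustrative only, not from the PR.
original_grad_outputs = torch.randn(2, 8, 16)  # gradient w.r.t. the chain's outputs, as received
block_grad_part = torch.randn(2, 8, 16)        # stand-in for the gradient contributed by the blocks themselves
grad_outputs = original_grad_outputs + block_grad_part  # accumulated gradient, including the residual passthrough

# The corrected line: subtract the residual passthrough in place, reusing original_grad_outputs as the output buffer.
grad_inputs = torch.sub(grad_outputs, original_grad_outputs, out=original_grad_outputs)

assert grad_inputs is original_grad_outputs             # `out=` returns the buffer it wrote into, so no extra allocation
assert torch.allclose(grad_inputs, block_grad_part)     # the residual term is removed, matching the docstring's ":note:"
```

With the old first argument (`grad_prompts`), the subtraction would have mixed the prompt gradients into `grad_inputs` and never removed the residual term from `grad_outputs`, which is what the docstring added in the first hunk now promises.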