@@ -137,11 +137,12 @@ class TransformerConnectionHandler(ConnectionHandler):
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO unfuck
+
         grad_inputs_schema_with_prompts = (
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].kwargs_schema,
-        )
+        )  # TODO unfuck
 
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
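
Why `args_schema * len(grads)` lines up with either return shape: a minimal plain-Python sketch (the descriptor string is a hypothetical stand-in):

```python
# Hypothetical stand-in: args_schema has length 1, as the assert above requires.
args_schema = ("hidden_states_descriptor",)

for grads in (["grad_outputs"], ["grad_outputs", "grad_prompts"]):
    schema = args_schema * len(grads)  # tuple repetition: 1 or 2 descriptors
    assert len(schema) == len(grads)   # one descriptor per returned grad tensor
```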

@@ -162,11 +163,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO unfuck
         grad_inputs_schema_with_prompts = (
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].kwargs_schema,
-        )
+        )  # TODO unfuck
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
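
For context, serialization pairs each grad with its descriptor roughly like this sketch (assuming hivemind's `nested_flatten` and `serialize_torch_tensor`; exact import paths differ across hivemind versions):

```python
from hivemind import nested_flatten, serialize_torch_tensor  # paths may vary by version

def serialize_grads(grads, grad_inputs_schema_with_prompts):
    # nested_flatten walks the (args, kwargs) schema in order, so descriptors
    # line up index-for-index with the 1- or 2-element grads list from above.
    return [
        serialize_torch_tensor(grad.to(proto.dtype), proto.compression)
        for grad, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
    ]
```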

@@ -245,11 +246,13 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     return hidden_states
 
 
-async def _rpc_backward(inputs: torch.Tensor, prompts: torch.Tensor, grad_outputs: torch.Tensor, requested_backends):
+async def _rpc_backward(*flat_tensors: torch.Tensor, requested_backends):
+    inputs, grad_outputs, *prompts = flat_tensors  # prompts arrives as a (possibly empty) list
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
-    prompts = prompts.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+    NO_PROMPTS = not prompts
+    prompts = prompts[0].to(requested_backends[0].dtype) if prompts else DUMMY
 
     if is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
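
The star-unpacking contract the new signature assumes, as a self-contained sketch: callers send either 2 tensors (no prompts) or 3 (with a prompts tensor):

```python
import torch

def unpack(*flat_tensors):
    inputs, grad_outputs, *prompts = flat_tensors  # same pattern as _rpc_backward
    return inputs, grad_outputs, prompts

_, _, no_prompts = unpack(torch.zeros(1), torch.zeros(1))
assert no_prompts == []            # empty list -> NO_PROMPTS is True, DUMMY path

_, _, with_prompts = unpack(torch.zeros(1), torch.zeros(1), torch.zeros(1))
assert len(with_prompts) == 1      # the single prompts tensor, wrapped in a list
```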

@@ -283,5 +286,4 @@ async def _rpc_backward(inputs: torch.Tensor, prompts: torch.Tensor, grad_output
 
     is_dummy_grad_prompts = [is_dummy(grad_param) for grad_param in grad_prompts]
     grad_prompts = torch.cat(grad_prompts, dim=0) if not any(is_dummy_grad_prompts) else DUMMY
-    grads = [grad_outputs, grad_prompts]
-    return grads
+    return [grad_outputs] if NO_PROMPTS else [grad_outputs, grad_prompts]  # TODO un-duct-tape
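
The asymmetric return keeps prompt-free callers on the old single-grad path; a sketch of the contract that the `len(grads) in (1, 2)` asserts above rely on (stand-in strings instead of tensors):

```python
def backward_result(no_prompts):
    grad_outputs, grad_prompts = "grad_outputs", "grad_prompts"  # stand-ins
    return [grad_outputs] if no_prompts else [grad_outputs, grad_prompts]

assert len(backward_result(no_prompts=True)) == 1   # legacy, prompt-free callers
assert len(backward_result(no_prompts=False)) == 2  # prompt-tuned callers
```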
|