@@ -137,12 +137,12 @@ class TransformerConnectionHandler(ConnectionHandler):
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO unfuck
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
 
         grad_inputs_schema_with_prompts = (
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].kwargs_schema,
-        )  # TODO unfuck
+        )  # TODO generalize
 
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
@@ -163,11 +163,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO unfuck
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
         grad_inputs_schema_with_prompts = (
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].kwargs_schema,
-        )  # TODO unfuck
+        )  # TODO generalize
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [