@@ -12,7 +12,7 @@ from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
-from src.utils.misc import DUMMY, is_dummy, is_dummy_batch, make_dummy_batch
+from src.utils.misc import DUMMY, is_dummy
 
 
 class TransformerConnectionHandler(ConnectionHandler):
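
The patch drops `is_dummy_batch` and `make_dummy_batch`: missing prompts are now represented by a single shared zero-size `DUMMY` placeholder instead of per-batch dummies. For orientation, here is a minimal sketch of what these helpers plausibly look like in `src/utils/misc.py` (assumed; that file is not shown in this diff):

```python
import torch

# Zero-element placeholder used wherever no prompt is supplied.
DUMMY = torch.empty(0, requires_grad=True)

def is_dummy(tensor: torch.Tensor) -> bool:
    # A tensor is a dummy iff it carries no elements.
    return tensor.numel() == 0
```
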
@@ -128,11 +128,18 @@ class TransformerConnectionHandler(ConnectionHandler):
 
         grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)
 
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )
+
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(grads, nested_flatten(requested_backends[0].grad_inputs_schema))
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
             ]
         )
 
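
`_rpc_backward` now returns two tensors (`grad_outputs` and `grad_prompts`), but each backend's schema describes a single positional input, so the schema tuple is repeated once per gradient before `zip` pairs every gradient with a serialization spec. A toy illustration of that tuple arithmetic, with plain strings standing in for the real tensor descriptors:

```python
args_schema = ("hidden_states_descriptor",)     # one spec per positional input
grads = ["grad_hidden_states", "grad_prompts"]  # two gradients to serialize

assert len(args_schema) == 1 and len(grads) == 2
schema = args_schema * len(grads)  # -> ("hidden_states_descriptor", "hidden_states_descriptor")

for result, proto in zip(grads, schema):
    print(f"{result} -> serialized with {proto}")
```
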
@@ -146,10 +153,17 @@ class TransformerConnectionHandler(ConnectionHandler):
 
         grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)
 
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )
+
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(grads, nested_flatten(requested_backends[0].grad_inputs_schema))
+            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
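
The streaming variant applies the same schema fix before chunking. A sketch of how the serialized gradients are typically split with hivemind's `split_for_streaming` (imported at the top of this file; the chunk size below is an assumption, the real limit comes from the surrounding code):

```python
from hivemind.utils.streaming import split_for_streaming

def split_all(serialized_grad_inputs, chunk_size_bytes=2 ** 20):
    # Flatten every serialized tensor into gRPC-sized chunks.
    return [
        part
        for tensor in serialized_grad_inputs
        for part in split_for_streaming(tensor, chunk_size_bytes)
    ]
```
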
@@ -200,17 +214,20 @@ class TransformerConnectionHandler(ConnectionHandler):
 
 async def _rpc_forward(inputs, requested_backends):
     # Cast inputs to backend dtype
-    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
-    assert len(hidden_states) == 2 and hidden_states[0].ndim == 3
-    hidden_states, prompts = hidden_states
+    inputs = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
+    assert len(inputs) == 2 and inputs[0].ndim == 3
+    hidden_states, prompts = inputs
 
     if is_dummy(prompts):
-        batch_size = hidden_states.shape[0]
-        prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        pre_seq_len = prompts.shape[2]
 
     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
-        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, prompt)
+        if not is_dummy(prompt):
+            hidden_states[:, :pre_seq_len] += prompt
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         assert isinstance(hidden_states, torch.Tensor)
         assert (
             hidden_states.ndim == 3
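
Prompts are no longer forwarded into `forward_pool`; instead, the handler adds each block's prompt onto the first `pre_seq_len` positions of the hidden states before submitting the task. A self-contained illustration of that indexing (shapes are made up for the example):

```python
import torch

batch, seq_len, hid, pre_seq_len = 2, 8, 16, 4
hidden_states = torch.randn(batch, seq_len, hid)
prompt = torch.randn(batch, pre_seq_len, hid)  # this block's deep prompt

# Inject the prompt additively into the first pre_seq_len token positions.
hidden_states[:, :pre_seq_len] += prompt
assert hidden_states.shape == (batch, seq_len, hid)  # shape is unchanged
```
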
@@ -225,11 +242,11 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
     inputs = inputs.to(requested_backends[0].dtype)
     prompts = prompts.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
-    batch_size = inputs.shape[0]
 
     if is_dummy(prompts):
-        prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
+        prompts = [DUMMY] * len(requested_backends)
     else:
+        pre_seq_len = prompts.shape[2]
         prompts = [p.squeeze(0) for p in prompts.split(1)]
 
     # Run a forward chain to collect intermediate inputs
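
In `_rpc_backward`, a non-dummy `prompts` tensor is stacked as `[num_backends, batch, pre_seq_len, hid]`, so `pre_seq_len` is read from dim 2 and the stack is split into one prompt per backend. A quick sketch of that split (shapes assumed from the indexing used throughout this patch):

```python
import torch

num_backends, batch, pre_seq_len, hid = 3, 2, 4, 16
prompts = torch.randn(num_backends, batch, pre_seq_len, hid)

per_block = [p.squeeze(0) for p in prompts.split(1)]  # split along dim 0
assert len(per_block) == num_backends
assert per_block[0].shape == (batch, pre_seq_len, hid)
```
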
@@ -237,19 +254,25 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
     inter_inputs = [inputs]
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-        (inputs,) = await backend.forward_pool.submit_task(inputs, prompt)
+        if not is_dummy(prompt):
+            inputs[:, :pre_seq_len] += prompt
+        (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
         inter_inputs.append(inputs)
 
     grad_prompts = []
     # Run a chain of requested backends
     for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
-        grads = await backend.backward_pool.submit_task(inp, prompt, grad_outputs)
-        assert isinstance(grads, (list, tuple)) and len(grads) == 2
-        grad_outputs, grad_prompt = grads
-        grad_prompts.append(grad_prompt[None])
-
-    is_dummy_grad_prompts = [is_dummy_batch(grad_param, batch_size) for grad_param in grad_prompts]
+        if not is_dummy(prompt):
+            inp[:, :pre_seq_len] += prompt
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(grad_outputs, torch.Tensor)
+        if not is_dummy(prompt):
+            grad_prompts.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+        else:
+            grad_prompts.append(DUMMY)
+
+    is_dummy_grad_prompts = [is_dummy(grad_param) for grad_param in grad_prompts]
     grad_prompts = torch.cat(grad_prompts, dim=0) if not any(is_dummy_grad_prompts) else DUMMY
     grads = [grad_outputs, grad_prompts]
     return grads
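
Because each prompt enters the computation additively (`inp[:, :pre_seq_len] += prompt`), its gradient is exactly the slice of the input gradient over those positions, which is why the new code can recover `grad_prompts` from `grad_outputs` instead of asking the backend for a second output. A small autograd check of that identity (toy shapes, arbitrary downstream weights standing in for the rest of the chain):

```python
import torch

batch, seq_len, hid, pre_seq_len = 2, 8, 16, 4
hidden = torch.randn(batch, seq_len, hid, requires_grad=True)
prompt = torch.randn(batch, pre_seq_len, hid, requires_grad=True)

# Additive injection, written functionally so autograd tracks both inputs.
injected = torch.cat([hidden[:, :pre_seq_len] + prompt, hidden[:, pre_seq_len:]], dim=1)
loss = (injected * torch.randn(batch, seq_len, hid)).sum()
loss.backward()

# The prompt gradient equals the first pre_seq_len slice of the input gradient.
assert torch.allclose(prompt.grad, hidden.grad[:, :pre_seq_len])
```
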
|