@@ -251,7 +251,6 @@ async def _rpc_backward(
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
-    NO_PROMPTS = not prompts
     prompts = prompts.to(requested_backends[0].dtype) if prompts else DUMMY

     if is_dummy(prompts):
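
The dropped `NO_PROMPTS` flag was redundant: the same "no prompts" state is already encoded by the `DUMMY` sentinel that `prompts` falls back to. As a minimal sketch of that convention, assuming hivemind-style definitions of `DUMMY` and `is_dummy` (an empty placeholder tensor and a numel check; the exact definitions are not part of this diff):

```python
import torch

# Assumed definitions, mirroring hivemind-style utils (not shown in this diff)
DUMMY = torch.empty(0, requires_grad=True)  # stand-in for "no tensor was sent"

def is_dummy(tensor: torch.Tensor) -> bool:
    # An empty tensor encodes the absence of a value
    return tensor.numel() == 0

prompts = DUMMY                            # e.g. the client sent no soft prompts
assert is_dummy(prompts)                   # the sentinel survives serialization as a real tensor
assert not is_dummy(torch.zeros(1, 4, 8))  # any non-empty tensor counts as real data
```

Because the sentinel travels with the data, the return statement at the end of the function can test it directly instead of threading a separate boolean through the whole body.
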
@@ -266,11 +265,10 @@ async def _rpc_backward(
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs = inputs.clone()  # TODO
             inputs[:, :pre_seq_len] += prompt
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
-        inter_inputs.append(inputs)
+        inter_inputs.append(inputs.clone())  # TODO optimize: reduce the number of copies

     grad_prompts = []
     # Run a chain of requested backends
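
Moving the `.clone()` from the prompt-addition branch to the `inter_inputs.append(...)` call snapshots every saved activation, not just the ones that received a prompt: otherwise the tensor stored in `inter_inputs` aliases the live `inputs`, and the next iteration's in-place `+=` would corrupt it. A standalone demonstration of that aliasing hazard with plain tensors (no petals/hivemind imports; shapes are illustrative):

```python
import torch

saved = []
x = torch.ones(2, 4)

saved.append(x)        # old behavior: stores a reference, not a copy
x[:, :2] += 1.0        # a later in-place update, like `inputs[:, :pre_seq_len] += prompt`
print(saved[0][0, 0])  # tensor(2.) -- the "saved" activation was silently mutated

saved.clear()
x = torch.ones(2, 4)
saved.append(x.clone())  # new behavior: snapshot before any further mutation
x[:, :2] += 1.0
print(saved[0][0, 0])    # tensor(1.) -- the snapshot is intact
```

The trade-off, flagged by the `TODO optimize` comment, is one extra copy per backend even when no prompt is present.
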
@@ -281,9 +279,6 @@ async def _rpc_backward(
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
             grad_prompts.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
-        else:
-            grad_prompts.append(DUMMY)

-    is_dummy_grad_prompts = [is_dummy(grad_param) for grad_param in grad_prompts]
-    grad_prompts = torch.cat(grad_prompts, dim=0) if not any(is_dummy_grad_prompts) else DUMMY
-    return [grad_outputs] if NO_PROMPTS else [grad_outputs, grad_prompts]  # TODO un-duct-tape
+    grad_prompts = torch.cat(grad_prompts, dim=0) if grad_prompts else DUMMY
+    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
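
With the `else: grad_prompts.append(DUMMY)` branch removed, `grad_prompts` is either empty or contains only real gradients, so plain list truthiness replaces the per-element `is_dummy` scan, and the `DUMMY` fallback doubles as the old `NO_PROMPTS` signal at return time. A shape-level sketch of the concatenation, with made-up sizes (`num_layers`, `batch`, `pre_seq_len`, `hid` are illustrative, not from the diff):

```python
import torch

num_layers, batch, pre_seq_len, hid = 3, 2, 4, 8
DUMMY = torch.empty(0)

# One entry per backend that received a prompt: grad_outputs[:, :pre_seq_len]
# has shape [batch, pre_seq_len, hid]; unsqueeze(0) adds the per-layer axis.
grad_prompts = [torch.randn(batch, pre_seq_len, hid).unsqueeze(0) for _ in range(num_layers)]

stacked = torch.cat(grad_prompts, dim=0) if grad_prompts else DUMMY
assert stacked.shape == (num_layers, batch, pre_seq_len, hid)

# With no prompts anywhere, the list stays empty and the caller gets the sentinel:
no_prompt_grads = []
result = torch.cat(no_prompt_grads, dim=0) if no_prompt_grads else DUMMY
assert result.numel() == 0  # i.e. is_dummy(result) holds, so only grad_outputs is returned
```
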
|