@@ -88,11 +88,11 @@ async def run_rpc_backward(
     points: int = 0,
     args_structure: Any,
 ) -> Tuple[Sequence[torch.Tensor], Any]:
-    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(
+    assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
+    ((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs(
         requested_backends, flat_tensors, args_structure
     )
     # Cast inputs & grad outputs to backend dtype
-    assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
     num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
     hidden_states = hidden_states.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
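Note: this hunk moves the requires_grad guard ahead of _check_inputs, so a malformed request fails before any unpacking or dtype casting, and it flips the expected wire order so gradients come first. A minimal, self-contained illustration of the guard's behavior (the guard helper name and tensor shapes are hypothetical, not part of this diff):

    import torch

    def guard(flat_tensors):
        # Fail fast if no input can produce gradients, before unpacking/casting
        assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"

    guard((torch.randn(2, 8, 16, requires_grad=True), torch.randn(2, 8, 16)))  # passes
    guard((torch.randn(2, 8, 16), torch.randn(2, 8, 16)))  # raises AssertionError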
@@ -117,8 +117,7 @@ async def run_rpc_backward(
         (hidden_states,) = await backend.forward_pool.submit_task(
             active_adapter, hidden_states, **kwargs, priority=priority, size=num_tokens
         )
-
-        assert isinstance(hidden_states, torch.Tensor)
+        assert isinstance(hidden_states, torch.Tensor), "intermediate hidden states is not a tensor"
 
     if not is_dummy(prompts[-1]):
         hidden_states[:, : prompts[-1].shape[1]] += prompts[-1]
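For context, the surrounding lines add a tuned prompt to the first positions of the recomputed activations. A sketch of that slice-and-add with illustrative shapes (batch=2, seq_len=8, hidden=16, prompt_len=4; all values made up):

    import torch

    hidden_states = torch.zeros(2, 8, 16)
    prompt = torch.ones(2, 4, 16)

    # Only the first prompt.shape[1] positions receive the tuned prompt
    hidden_states[:, : prompt.shape[1]] += prompt
    assert hidden_states[:, :4].eq(1).all() and hidden_states[:, 4:].eq(0).all()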
@@ -129,13 +128,15 @@ async def run_rpc_backward(
     grad_backend_kwargs_reversed = []
 
     # Run a chain of requested backends
-    for inp, prompt, backend, kwargs in reversed(list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))):
+    for hidden_states, prompt, backend, kwargs in reversed(list(zip(
+        inter_inputs, prompts, requested_backends, backend_kwargs))):
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        hidden_states = hidden_states.detach().requires_grad_(True)
         priority = prioritizer.prioritize(
-            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
+            hidden_states, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
         (*grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task(
-            active_adapter, grad_outputs, inp, **kwargs, priority=priority, size=num_tokens
+            active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens
         )
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
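Note: the added detach().requires_grad_(True) re-roots autograd at each stage, so each block's backward stops at its own input instead of flowing into an upstream graph. A runnable sketch of that pattern with toy tensors (not the diff's code):

    import torch

    x = torch.randn(2, 8, requires_grad=True)
    y = (x * 2).tanh()                   # activation still attached to x's graph

    # Re-root: y becomes a leaf, so backward through this stage stops here
    y = y.detach().requires_grad_(True)
    y.pow(2).sum().backward()
    assert y.grad is not None            # gradient w.r.t. this stage's input
    assert x.grad is None                # upstream graph untouched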
@@ -252,10 +253,14 @@ async def iterate_rpc_inference(
 def _check_inputs(
     requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], args_structure: Any
 ):
+    if len(flat_tensors) == 3:  # backward compatibility for rpc_backward, remove after 2.3
+        if flat_tensors[0].requires_grad and not flat_tensors[1].requires_grad:
+            hidden_states, grad_outputs, prompts = flat_tensors
+            flat_tensors = grad_outputs, hidden_states, prompts
     if args_structure is not None:
         args, *backend_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
     else:
-        args, *backend_kwargs = flat_tensors, {}  # backward compatibility
+        args, *backend_kwargs = flat_tensors, {}  # backward compatibility for grad structure, remove at 2.2
 
     if len(backend_kwargs) not in (1, len(requested_backends)):
         raise RuntimeError(
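Note: the compatibility branch detects the old (hidden_states, grad_outputs, prompts) wire order by checking which tensor carries gradients, then swaps to the new gradients-first order. A standalone sketch of the heuristic (maybe_reorder is a hypothetical name; shapes are illustrative):

    import torch

    def maybe_reorder(flat_tensors):
        # Old clients: hidden_states (requires_grad) came first, grad_outputs second
        if len(flat_tensors) == 3 and flat_tensors[0].requires_grad and not flat_tensors[1].requires_grad:
            hidden_states, grad_outputs, prompts = flat_tensors
            flat_tensors = grad_outputs, hidden_states, prompts
        return flat_tensors

    old = (torch.randn(2, 4, requires_grad=True), torch.randn(2, 4), torch.zeros(1))
    new = maybe_reorder(old)
    assert new[0] is old[1] and new[1] is old[0]  # reordered to gradients-first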
|