Your Name 1 year ago
parent commit 465fd93147
3 changed files with 20 additions and 10 deletions
  1. + 2 - 0   src/petals/client/remote_forward_backward.py
  2. + 5 - 2   src/petals/client/sequential_autograd.py
  3. + 13 - 8  src/petals/server/block_functions.py

+ 2 - 0
src/petals/client/remote_forward_backward.py

@@ -144,6 +144,8 @@ async def run_remote_backward(
             for tensor, compression in zip(flat_tensors, codecs)
         )
     )
+    for tensor, serialized_tensor in zip(flat_tensors, serialized_tensors):
+        serialized_tensor.requires_grad = tensor.requires_grad
 
     size = sum(t.element_size() * t.nelement() for t in flat_tensors)
     backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
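
Review note: the two added lines copy each input tensor's requires_grad flag onto its serialized counterpart, so the server can tell which tensors actually need gradients. A minimal sketch of the pattern, assuming hivemind's serialize_torch_tensor and a protobuf Tensor message that exposes a requires_grad field (as the diff implies); the helper name is made up for illustration:

    import torch
    from hivemind.compression import serialize_torch_tensor

    def serialize_with_grad_flags(flat_tensors):
        # Serialize each tensor, then mirror its requires_grad flag onto the
        # protobuf, matching the two lines added to run_remote_backward above.
        serialized = [serialize_torch_tensor(t) for t in flat_tensors]
        for tensor, proto in zip(flat_tensors, serialized):
            proto.requires_grad = tensor.requires_grad
        return serialized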

+ 5 - 2
src/petals/client/sequential_autograd.py

@@ -161,7 +161,8 @@ async def sequential_backward(
                     span.peer_id,
                     sequence_manager.block_uids[span.start : span.end],
                     grad_outputs,
-                    *inputs,
+                    inputs,
+                    prompts[span.start : span.end],
                     *block_kwargs[span.start : span.end],
                 )
                 grad_outputs = [grad_outputs]
@@ -224,12 +225,14 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+        input_batches = tuple(batch.requires_grad_(inputs.requires_grad) for batch in input_batches)
         if prompts is None or is_dummy(prompts):
             prompt_batches = [DUMMY] * len(input_batches)
         else:
             prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
+            prompt_batches = tuple(batch.requires_grad_(prompts.requires_grad) for batch in prompt_batches)
 
-        sequence_manager.rpc_info  # lazy init
+    sequence_manager.rpc_info  # lazy init  # TODO: no longer needed
         outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
         assert len(outputs) == len(input_batches)
 

+ 13 - 8
src/petals/server/block_functions.py

@@ -88,11 +88,11 @@ async def run_rpc_backward(
     points: int = 0,
     args_structure: Any,
 ) -> Tuple[Sequence[torch.Tensor], Any]:
-    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(
+    assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
+    ((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs(
         requested_backends, flat_tensors, args_structure
     )
     # Cast inputs & grad outputs to backend dtype
-    assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
     num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
     hidden_states = hidden_states.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
@@ -117,8 +117,7 @@ async def run_rpc_backward(
         (hidden_states,) = await backend.forward_pool.submit_task(
             active_adapter, hidden_states, **kwargs, priority=priority, size=num_tokens
         )
-
-        assert isinstance(hidden_states, torch.Tensor)
+        assert isinstance(hidden_states, torch.Tensor), "intermediate hidden states is not a tensor"
 
     if not is_dummy(prompts[-1]):
         hidden_states[:, : prompts[-1].shape[1]] += prompts[-1]
@@ -129,13 +128,15 @@ async def run_rpc_backward(
     grad_backend_kwargs_reversed = []
 
     # Run a chain of requested backends
-    for inp, prompt, backend, kwargs in reversed(list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))):
+    for hidden_states, prompt, backend, kwargs in reversed(list(zip(
+            inter_inputs, prompts, requested_backends, backend_kwargs))):
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        hidden_states = hidden_states.detach().requires_grad_(True)
         priority = prioritizer.prioritize(
-            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
+            hidden_states, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
         (*grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task(
-            active_adapter, grad_outputs, inp, **kwargs, priority=priority, size=num_tokens
+            active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens
         )
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
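
Review note: the loop now detaches each saved intermediate activation and re-enables requires_grad on it before submitting the block's backward task, so gradients can be taken with respect to that activation alone. A self-contained illustration of the same recompute-and-backprop pattern (plain nn.Linear blocks stand in for transformer blocks):

    import torch

    def backprop_through_saved_inputs(blocks, saved_inputs, grad_output):
        # Walk the chain in reverse; for each block, rebuild a local graph from
        # the detached saved input and push the incoming gradient through it.
        for block, saved in reversed(list(zip(blocks, saved_inputs))):
            hidden_states = saved.detach().requires_grad_(True)
            output = block(hidden_states)
            (grad_output,) = torch.autograd.grad(output, hidden_states, grad_output)
        return grad_output  # gradient w.r.t. the first block's input

    blocks = [torch.nn.Linear(8, 8) for _ in range(3)]
    saved = [torch.randn(2, 8)]
    for block in blocks[:-1]:
        saved.append(block(saved[-1]).detach())
    grad_in = backprop_through_saved_inputs(blocks, saved, torch.ones(2, 8))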
@@ -252,10 +253,14 @@ async def iterate_rpc_inference(
 def _check_inputs(
     requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], args_structure: Any
 ):
+    if len(flat_tensors) == 3:  # backward compatibility for rpc_backward, remove after 2.3
+        if flat_tensors[0].requires_grad and not flat_tensors[1].requires_grad:
+            hidden_states, grad_outputs, prompts = flat_tensors
+            flat_tensors = grad_outputs, hidden_states, prompts
     if args_structure is not None:
         args, *backend_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
     else:
-        args, *backend_kwargs = flat_tensors, {}  # backward compatibility
+        args, *backend_kwargs = flat_tensors, {}  # backward compatibility for grad structure, remove at 2.2
 
     if len(backend_kwargs) not in (1, len(requested_backends)):
         raise RuntimeError(
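
Review note: the new shim in _check_inputs relies on requires_grad flags to tell the legacy argument order (hidden_states, grad_outputs, prompts) apart from the new one (grad_outputs, hidden_states, prompts): hidden states carry requires_grad=True while grad outputs do not. A hedged sketch of that detection logic in isolation:

    import torch

    def reorder_legacy_backward_args(flat_tensors):
        # Old clients send (hidden_states, grad_outputs, prompts); new clients
        # send (grad_outputs, hidden_states, prompts). Only the hidden states
        # require grad, so the flag pattern identifies the legacy ordering.
        if len(flat_tensors) == 3 and flat_tensors[0].requires_grad and not flat_tensors[1].requires_grad:
            hidden_states, grad_outputs, prompts = flat_tensors
            return grad_outputs, hidden_states, prompts
        return tuple(flat_tensors)

    legacy = (torch.randn(1, 4, 8, requires_grad=True), torch.randn(1, 4, 8), torch.randn(1, 1, 8))
    assert not reorder_legacy_backward_args(legacy)[0].requires_grad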