Browse Source

black-isort

Your Name 2 years ago
parent
commit
fb9b21132c
1 changed file with 7 additions and 3 deletions
  1. +7 -3
      src/petals/server/backend.py

+ 7 - 3
src/petals/server/backend.py

@@ -17,6 +17,7 @@ from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
+from petals.utils.packaging import pack_args_kwargs
 
 logger = get_logger(__name__)
 
@@ -112,14 +113,17 @@ class TransformerBackend(ModuleBackend):
     def backward(
         self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs
     ) -> Tuple[torch.Tensor, ...]:
+        args = [x.detach().requires_grad_(True) if x.is_floating_point() else x.detach() for x in args]
+        # TODO remove this WITHIN PR#467; make sure args are passed properly and retain requires_grad
         assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor))
         with self._peft_module.using_adapter(active_adapter), torch.enable_grad():
             (outputs,) = self.module(*args, **kwargs)
             assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape
             torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False)
-        return nested_map(
-            lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs)
-        )
+        # flat_tensors, structure = pack_args_kwargs(nested_map(
+        #     lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs))
+        # )
+        return (args[0].grad,)  # TODO pass additional kwarg-grads back to the user WITHIN #467
 
     @torch.inference_mode()
     def inference_step(