@@ -17,6 +17,7 @@ from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
+from petals.utils.packaging import pack_args_kwargs
 
 logger = get_logger(__name__)
 
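For context on the new import: once the gradients for every tensor in `(args, kwargs)` are returned (see the commented-out block in `backward` in the next hunk), they would be flattened into a list of tensors plus a structure descriptor that the receiving side can use to rebuild the nested layout. Below is a toy sketch of that pack/unpack contract, assuming behavior along the lines of `petals/utils/packaging.py`; the `toy_pack`/`toy_unpack` names and this exact signature are illustrative, not the Petals API.

```python
import torch
from typing import Any, Dict, List, Tuple

def toy_pack(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[List[torch.Tensor], Any]:
    """Pull tensors out of (args, kwargs) into a flat list; leave placeholders behind."""
    flat: List[torch.Tensor] = []
    def strip(value: Any) -> Any:
        if isinstance(value, torch.Tensor):
            flat.append(value)
            return None  # placeholder marking a tensor slot (toy simplification)
        return value
    structure = (tuple(strip(v) for v in args), {k: strip(v) for k, v in kwargs.items()})
    return flat, structure

def toy_unpack(flat: List[torch.Tensor], structure: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Inverse of toy_pack: splice the flat tensors back into their placeholder slots."""
    it = iter(flat)
    args, kwargs = structure
    fill = lambda value: next(it) if value is None else value
    return tuple(fill(v) for v in args), {k: fill(v) for k, v in kwargs.items()}

flat_tensors, structure = toy_pack((torch.zeros(2, 3),), {"prompts": torch.ones(4)})
args, kwargs = toy_unpack(flat_tensors, structure)
assert torch.equal(kwargs["prompts"], torch.ones(4))
```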
@@ -112,14 +113,17 @@ class TransformerBackend(ModuleBackend):
     def backward(
         self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs
     ) -> Tuple[torch.Tensor, ...]:
+        args = [x.detach().requires_grad_(True) if x.is_floating_point() else x.detach() for x in args]
+        # TODO remove this WITHIN PR#467; make sure args are passed properly and retain requires_grad
         assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor))
         with self._peft_module.using_adapter(active_adapter), torch.enable_grad():
             (outputs,) = self.module(*args, **kwargs)
             assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape
             torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False)
-        return nested_map(
-            lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs)
-        )
+        # flat_tensors, structure = pack_args_kwargs(nested_map(
+        #     lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs))
+        # )
+        return (args[0].grad,)  # TODO pass additional kwarg-grads back to the user WITHIN #467
 
     @torch.inference_mode()
     def inference_step(
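The rewritten `backward` follows a standard PyTorch recompute-and-backprop pattern: detach the incoming activations from any upstream graph, mark the floating-point ones as requiring grad, re-run the module under `torch.enable_grad()`, and seed `torch.autograd.backward` with the upstream gradients so that `.grad` is populated on the detached inputs. A minimal self-contained sketch of that pattern in plain PyTorch, with an `nn.Linear` standing in for the served transformer block (`backward_step` is an illustrative name, not a Petals function):

```python
import torch
import torch.nn as nn

def backward_step(module: nn.Module, grad_outputs: torch.Tensor, *args: torch.Tensor):
    # Detach inputs from any upstream graph; only floating-point tensors can carry grads.
    args = [x.detach().requires_grad_(True) if x.is_floating_point() else x.detach() for x in args]
    with torch.enable_grad():  # re-enable autograd even if the caller runs under no_grad
        outputs = module(*args)
        # Seed the backward pass with the gradients received from the next pipeline stage.
        torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,))
    # Mirrors the PR's current behavior: only the first input's grad is returned.
    return (args[0].grad,)

block = nn.Linear(16, 16)  # stand-in for the transformer block
hidden_states = torch.randn(2, 16)
(grad_hidden,) = backward_step(block, torch.ones(2, 16), hidden_states)
assert grad_hidden.shape == hidden_states.shape
```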