wrapper.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import torch
  2. class OptimizerWrapper(torch.optim.Optimizer):
  3. r"""
  4. A wrapper for pytorch.optimizer that forwards all methods to the wrapped optimizer
  5. """
  6. def __init__(self, optim: torch.optim.Optimizer):
  7. object.__init__(self)
  8. self.optim = optim
  9. @property
  10. def defaults(self):
  11. return self.optim.defaults
  12. @property
  13. def state(self):
  14. return self.optim.state
  15. def __getstate__(self):
  16. return self.optim.__getstate__()
  17. def __setstate__(self, state):
  18. self.optim.__setstate__(state)
  19. def __repr__(self):
  20. return f"{self.__class__.__name__}({repr(self.optim)})"
  21. def state_dict(self):
  22. return self.optim.state_dict()
  23. def load_state_dict(self, state_dict: dict) -> None:
  24. return self.optim.load_state_dict(state_dict)
  25. def step(self, *args, **kwargs):
  26. return self.optim.step(*args, **kwargs)
  27. def zero_grad(self, *args, **kwargs):
  28. return self.optim.zero_grad(*args, **kwargs)
  29. @property
  30. def param_groups(self):
  31. return self.optim.param_groups
  32. def add_param_group(self, param_group: dict) -> None:
  33. return self.optim.add_param_group(param_group)