base.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from warnings import warn
  2. import torch
  3. from hivemind.dht import DHT
  4. class DecentralizedOptimizerBase(torch.optim.Optimizer):
  5. """A shared interface for all hivemind optimizers. Cooperates with DHT peers to train a shared model"""
  6. def __init__(self, opt: torch.optim.Optimizer, dht: DHT):
  7. self.opt, self.dht = opt, dht
  8. warn(
  9. "DecentralizedOptimizerBase and its subclasses have been deprecated and will be removed "
  10. "in hivemind 1.1.0. Use hivemind.Optimizer instead",
  11. FutureWarning,
  12. stacklevel=2,
  13. )
  14. @property
  15. def state(self):
  16. return self.opt.state
  17. @property
  18. def param_groups(self):
  19. return self.opt.param_groups
  20. def add_param_group(self, param_group: dict) -> None:
  21. raise ValueError(
  22. f"{self.__class__.__name__} does not support calling add_param_group after creation."
  23. f"Please provide all parameter groups at init."
  24. )
  25. def state_dict(self) -> dict:
  26. return self.opt.state_dict()
  27. def load_state_dict(self, state_dict: dict):
  28. return self.opt.load_state_dict(state_dict)
  29. def __repr__(self):
  30. return f"{self.__class__.__name__}(opt={repr(self.opt)}, dht={repr(self.dht)})"
  31. def shutdown(self):
  32. raise NotImplementedError()