base.py 1.0 KB

1234567891011121314151617181920212223242526272829303132333435
  1. from typing import Any
  2. import torch
  3. from hivemind.dht import DHT
  4. class DecentralizedOptimizerBase(torch.optim.Optimizer):
  5. """ A shared interface for all hivemind optimizers. Cooperates with DHT peers to train a shared model """
  6. def __init__(self, opt: torch.optim.Optimizer, dht: DHT):
  7. self.opt, self.dht = opt, dht
  8. @property
  9. def state(self):
  10. return self.opt.state
  11. @property
  12. def param_groups(self):
  13. return self.opt.param_groups
  14. def add_param_group(self, param_group: dict) -> None:
  15. raise ValueError(f"{self.__class__.__name__} does not support calling add_param_group after creation."
  16. f"Please provide all parameter groups at init.")
  17. def state_dict(self) -> dict:
  18. return self.opt.state_dict()
  19. def load_state_dict(self, state_dict: dict):
  20. return self.opt.load_state_dict(state_dict)
  21. def __repr__(self):
  22. return f"{self.__class__.__name__}(opt={repr(self.opt)}, dht={repr(self.dht)})"
  23. def shutdown(self):
  24. raise NotImplementedError()