adaptive.py

from typing import Sequence

import torch.optim

from hivemind.optim.collaborative import CollaborativeOptimizer
from hivemind import TrainingAverager


class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):
    """
    Behaves exactly as CollaborativeOptimizer except:

    - averages adaptive learning rates of an optimizer
    - doesn't average gradients

    :param average_opt_statistics: average optimizer statistics with corresponding names in the state dict
    :param kwargs: options for CollaborativeOptimizer
    """

    def __init__(self, opt: torch.optim.Optimizer, average_opt_statistics: Sequence[str], **kwargs):
        super().__init__(opt, average_opt_statistics=average_opt_statistics, **kwargs)

    def _make_averager(self, average_opt_statistics, **kwargs):
        # Average parameters and the requested optimizer statistics across peers, but not gradients.
        return TrainingAverager(
            self.opt,
            dht=self.dht,
            average_parameters=True,
            average_gradients=False,
            average_opt_statistics=average_opt_statistics,
            prefix=f"{self.prefix}_averaging",
            allreduce_timeout=self.averaging_timeout,
            listen=not self.client_mode,
            **kwargs,
        )
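
For context, a minimal usage sketch follows (not part of adaptive.py). It assumes the class is importable as hivemind.optim.adaptive.CollaborativeAdaptiveOptimizer, matching this file's location, and that CollaborativeOptimizer accepts the usual dht, prefix, target_batch_size, and batch_size_per_step keyword arguments; the experiment prefix and batch sizes are placeholder values, and "exp_avg_sq" is the name of Adam's second-moment statistic in its state dict.

# Usage sketch; names and values below are illustrative placeholders.
import torch
import hivemind
from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer

model = torch.nn.Linear(16, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# A standalone local DHT node; a real collaboration would pass initial_peers to join existing peers.
dht = hivemind.DHT(start=True)

collab_opt = CollaborativeAdaptiveOptimizer(
    opt,
    average_opt_statistics=("exp_avg_sq",),  # Adam's second-moment buffers, named as in the optimizer state dict
    dht=dht,
    prefix="my_experiment",      # placeholder experiment name shared by all peers
    target_batch_size=4096,      # placeholder global batch size that triggers a collaborative step
    batch_size_per_step=32,      # placeholder number of samples processed locally per step() call
    verbose=True,
)

# Training then proceeds as with any torch optimizer: loss.backward() followed by collab_opt.step().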