adaptive.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334
  1. from typing import Sequence
  2. import torch.optim
  3. from hivemind.optim.collaborative import CollaborativeOptimizer
  4. from hivemind.optim.training_averager import TrainingAverager
  5. class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):
  6. """
  7. Behaves exactly as CollaborativeOptimizer except:
  8. * averages adaptive learning rates of an optimizer
  9. * doesn't average gradients
  10. :param average_opt_statistics: average optimizer statistics with corresponding names in statedict
  11. :param kwargs: options for CollaborativeOptimizer
  12. """
  13. def __init__(self, opt: torch.optim.Optimizer, average_opt_statistics: Sequence[str], **kwargs):
  14. super().__init__(opt, average_opt_statistics=average_opt_statistics, **kwargs)
  15. def _make_averager(self, average_opt_statistics, **kwargs):
  16. return TrainingAverager(
  17. self.opt,
  18. dht=self.dht,
  19. average_parameters=True,
  20. average_gradients=False,
  21. average_opt_statistics=average_opt_statistics,
  22. prefix=f"{self.prefix}_averaging",
  23. allreduce_timeout=self.averaging_timeout,
  24. client_mode=self.client_mode,
  25. **kwargs,
  26. )