|
@@ -148,7 +148,7 @@ class Optimizer(torch.optim.Optimizer):
|
|
|
:param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
|
|
|
|
|
|
:param grad_compression: compression strategy used for averaging gradients, default = no compression
|
|
|
- :param grad_averager: if provided, creates gradient averager with required averaging strategy
|
|
|
+ :param grad_averager_factory: if provided, this factory is used to create the gradient averager with the required averaging strategy
|
|
|
:param state_averaging_compression: compression for averaging params and state tensors, default = no compression
|
|
|
:param load_state_compression: compression strategy for loading state from peers, default = no compression
|
|
|
:param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
|
|
@@ -230,7 +230,7 @@ class Optimizer(torch.optim.Optimizer):
|
|
|
assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
|
|
|
assert (
|
|
|
grad_averager_factory is None
|
|
|
- ), "if local_updates is True, provided gradient_averager will not be used"
|
|
|
+ ), "if local_updates is True, provided grad_averager_factory will not be used"
|
|
|
|
|
|
self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
|
|
|
self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
|