@@ -11,7 +11,7 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
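Note: `GradientAveragerFactory` tightens the annotation that previously read `Optional[Callable[..., GradientAverager]]` below. As a rough sketch of what the alias amounts to (the exact definition lives in hivemind/optim/grad_averager.py and may differ):

    # Sketch only, not the verbatim hivemind definition: a factory is any
    # callable that returns a GradientAverager when invoked with keyword
    # arguments.
    from typing import Callable

    from hivemind.optim.grad_averager import GradientAverager

    GradientAveragerFactory = Callable[..., GradientAverager]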
@@ -189,7 +189,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
-        grad_averager: Optional[Callable[..., GradientAverager]] = GradientAverager,
+        grad_averager_factory: Optional[GradientAveragerFactory] = GradientAverager,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
@@ -228,7 +228,9 @@ class Optimizer(torch.optim.Optimizer):
         if use_local_updates:
             assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
             assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
-            assert grad_averager is None, "if local_updates is True, provided gradient_averager will not be used"
+            assert (
+                grad_averager_factory is None
+            ), "if local_updates is True, provided gradient_averager will not be used"

         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
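With the rename, local-updates mode is requested as below; a hedged sketch assuming `dht` and `model` were already set up as in the hivemind quickstart:

    import torch
    import hivemind
    from functools import partial

    # use_local_updates=True applies the optimizer on every step without
    # accumulating or averaging gradients, so no factory may be passed:
    # anything other than None now trips the assert above.
    opt = hivemind.Optimizer(
        dht=dht,                    # assumes a running hivemind.DHT
        run_id="demo_run",          # hypothetical experiment name
        params=model.parameters(),  # assumes an existing torch model
        optimizer=partial(torch.optim.SGD, lr=0.1),
        target_batch_size=4096,
        batch_size_per_step=32,
        use_local_updates=True,
        grad_averager_factory=None,
    )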
@@ -259,9 +261,9 @@ class Optimizer(torch.optim.Optimizer):
             extra_tensors=extra_tensors,
             **averager_opts or {},
         )
-        if grad_averager is not None and not use_local_updates:
+        if grad_averager_factory is not None and not use_local_updates:
             self.grad_averager = self._make_gradient_averager(
-                reuse_grad_buffers=reuse_grad_buffers, grad_averager=grad_averager
+                reuse_grad_buffers=reuse_grad_buffers, grad_averager_factory=grad_averager_factory
             )
         else:
             self.grad_averager = None
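The factory indirection is what lets callers swap in a different averager class, e.g. the PowerSGDGradientAverager imported above. A usage sketch; parameter names such as averager_rank are assumptions about PowerSGD's signature, not guaranteed by this diff:

    import torch
    import hivemind
    from functools import partial
    from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager

    # partial() pre-binds averager-specific options; Optimizer fills in
    # dht, prefix, parameters, and reuse_grad_buffers when it invokes
    # the factory (see _make_gradient_averager below).
    opt = hivemind.Optimizer(
        dht=dht,                    # assumes a running hivemind.DHT
        run_id="demo_run",          # hypothetical experiment name
        params=model.parameters(),  # assumes an existing torch model
        optimizer=partial(torch.optim.Adam, lr=1e-3),
        target_batch_size=4096,
        batch_size_per_step=32,
        grad_averager_factory=partial(PowerSGDGradientAverager, averager_rank=4),
    )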
@@ -294,9 +296,9 @@ class Optimizer(torch.optim.Optimizer):
             **kwargs,
         )

-    def _make_gradient_averager(self, grad_averager, **kwargs) -> GradientAverager:
+    def _make_gradient_averager(self, grad_averager_factory, **kwargs) -> GradientAverager:
         assert hasattr(self, "state_averager"), "must initialize state averager first"
-        grad_averager = grad_averager(
+        grad_averager = grad_averager_factory(
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
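Since `_make_gradient_averager` calls the factory with keyword arguments (dht, prefix, parameters, plus reuse_grad_buffers from the earlier hunk), any callable with a compatible signature qualifies. A hypothetical custom factory:

    from hivemind.optim.grad_averager import GradientAverager

    # Hypothetical factory that logs construction and defers to the stock
    # averager; extra options pre-bound via functools.partial would also
    # arrive in **kwargs alongside what Optimizer passes in.
    def logging_grad_averager_factory(**kwargs) -> GradientAverager:
        print(f"building gradient averager, prefix={kwargs.get('prefix')}")
        return GradientAverager(**kwargs)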
@@ -409,7 +411,7 @@ class Optimizer(torch.optim.Optimizer):
                 self._maybe_schedule_state_averaging()

             else:
-                # grad_averager=None: update parameters on every step independently of other peers
+                # use_local_updates=True: update parameters on every step independently of other peers
                 if not self.auxiliary:
                     if grad_scaler is not None:
                         with grad_scaler.running_global_step():