@@ -12,6 +12,8 @@ from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
 from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.power_ef_averager import PowerEFGradientAverager
+from hivemind.optim.experimental.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.experimental.progress_tracker import ProgressTracker
 from hivemind.optim.experimental.state_averager import (
     LRSchedulerBase,
@@ -187,11 +189,13 @@ class Optimizer(torch.optim.Optimizer):
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
+        grad_rank_averager: Optional[str] = None,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
+        grad_averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
         performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
@@ -255,7 +259,7 @@ class Optimizer(torch.optim.Optimizer):
             )
         if not use_local_updates:
             self.grad_averager = self._make_gradient_averager(
-                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+                reuse_grad_buffers=reuse_grad_buffers, grad_rank_averager=grad_rank_averager, compression=grad_compression, **grad_averager_opts or {}
             )
         else:
             self.grad_averager = None
@@ -289,9 +293,15 @@ class Optimizer(torch.optim.Optimizer):
             **kwargs,
         )

-    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+    def _make_gradient_averager(self, grad_rank_averager, **kwargs) -> GradientAverager:
         assert hasattr(self, "state_averager"), "must initialize state averager first"
-        grad_averager = GradientAverager(
+        if grad_rank_averager == "power_ef":
+            grad_averager_type = PowerEFGradientAverager
+        elif grad_rank_averager == "power_sgd":
+            grad_averager_type = PowerSGDGradientAverager
+        else:
+            grad_averager_type = GradientAverager
+        grad_averager = grad_averager_type(
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,