@@ -6,7 +6,7 @@ import torch
 import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
-from hivemind.utils import DHTExpiration, get_logger
+from hivemind.utils import DHTExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
@@ -80,7 +80,7 @@ class GradientAverager(DecentralizedAverager):
         if reuse_grad_buffers and accumulate_grads_on is not None:
             logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
         client_mode = client_mode if client_mode is not None else dht.client_mode
-        self._parameters = tuple(parameters)
+        self.parameters = tuple(parameters)
         self.reuse_grad_buffers = reuse_grad_buffers
         self.warn = warn
         self.local_samples_accumulated = 0
@@ -102,7 +102,7 @@ class GradientAverager(DecentralizedAverager):
     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
         """gradient buffers associated with parameters"""
-        for param in self._parameters:
+        for param in self.parameters:
             if param.grad is None:
                 param.grad = torch.zeros_like(param)
             yield param.grad
@@ -152,6 +152,7 @@ class GradientAverager(DecentralizedAverager):
         weight: Optional[float] = None,
         reset_accumulators: bool = True,
         control: Optional[StepControl] = None,
+        timeout: Optional[float] = None,
         wait: bool = True,
         **kwargs,
     ):
@@ -161,12 +162,13 @@ class GradientAverager(DecentralizedAverager):
         :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
         :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
         :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
+        :param timeout: if specified, await for the averaging round for at most this number of seconds (if wait=True)
         :param wait: if True, await for the step to finish (or fail), otherwise run all-reduce in background
         """
         if control is None:
-            control = self.schedule_step(**kwargs)
+            control = self.schedule_step(timeout=timeout, **kwargs)
         elif len(kwargs) > 0:
-            RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
+            raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
         assert not control.triggered, f"This {type(control)} instance was already used."
         self._load_accumulators_into_averager_()
         self._accumulators_used_in_step = True
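
With the hunk above applied, step() forwards the new timeout both to schedule_step (matchmaking deadline) and, further below, to control.result. A minimal blocking-call sketch, assuming an already-constructed GradientAverager instance (the `averager` name and the 30-second budget are illustrative, not part of this diff):

    # run one synchronous averaging round, giving matchmaking and
    # all-reduce at most ~30 seconds before the call gives up
    averager.step(timeout=30.0)
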
@@ -175,9 +177,9 @@ class GradientAverager(DecentralizedAverager):
         control.weight = self.local_samples_accumulated if weight is None else weight
         if reset_accumulators:
             self.reset_accumulated_grads_()
-
         control.allow_allreduce()
-        return control.result() if wait else control
+
+        return control.result(timeout) if wait else control
 
     @torch.no_grad()
     def _load_accumulators_into_averager_(self):
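
The return statement above also supports a non-blocking pattern: with wait=False, step() hands back the StepControl so the caller can overlap local computation with the averaging round and collect the result later. A hedged sketch (variable names are illustrative):

    # launch matchmaking / all-reduce in the background
    control = averager.step(wait=False, timeout=60.0)
    ...  # accumulate more gradients or run forward/backward passes here
    control.result(60.0)  # block now; raises if the round fails or times out
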
@@ -209,11 +211,11 @@ class GradientAverager(DecentralizedAverager):
         self._new_averaged_grads = False
         with self.get_tensors() as averaged_grads:
             try:
-                assert len(averaged_grads) == len(self._parameters)
-                old_grads = [param.grad for param in self._parameters]
-                for param, new_grad in zip(self._parameters, averaged_grads):
+                assert len(averaged_grads) == len(self.parameters)
+                old_grads = [param.grad for param in self.parameters]
+                for param, new_grad in zip(self.parameters, averaged_grads):
                     param.grad = new_grad
                 yield
             finally:
-                for param, old_grad in zip(self._parameters, old_grads):
+                for param, old_grad in zip(self.parameters, old_grads):
                     param.grad = old_grad
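
The last hunk sits inside a try/finally context manager that temporarily swaps each param.grad for its averaged counterpart and restores the local gradients on exit. Assuming this is hivemind's use_averaged_gradients (the enclosing method name is outside this diff's context lines), usage would look like:

    # apply an optimizer update using the averaged gradients
    with averager.use_averaged_gradients():
        optimizer.step()  # sees averaged tensors via param.grad
    # local gradients are restored afterwards, even if step() raises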