Selaa lähdekoodia

Fix internal assert in GradientAverager (#410)

 Fix minor bugs in GradientAverager
justheuristic 3 vuotta sitten
vanhempi
commit
1cfc6a3b7b
1 muutettua tiedostoa jossa 13 lisäystä ja 11 poistoa
  1. 13 11
      hivemind/optim/experimental/grad_averager.py

+ 13 - 11
hivemind/optim/experimental/grad_averager.py

@@ -6,7 +6,7 @@ import torch
 import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
-from hivemind.utils import DHTExpiration, get_logger
+from hivemind.utils import DHTExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
@@ -80,7 +80,7 @@ class GradientAverager(DecentralizedAverager):
         if reuse_grad_buffers and accumulate_grads_on is not None:
             logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
         client_mode = client_mode if client_mode is not None else dht.client_mode
-        self._parameters = tuple(parameters)
+        self.parameters = tuple(parameters)
         self.reuse_grad_buffers = reuse_grad_buffers
         self.warn = warn
         self.local_samples_accumulated = 0
@@ -102,7 +102,7 @@ class GradientAverager(DecentralizedAverager):
 
     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
         """gradient buffers associated with parameters"""
-        for param in self._parameters:
+        for param in self.parameters:
             if param.grad is None:
                 param.grad = torch.zeros_like(param)
             yield param.grad
@@ -152,6 +152,7 @@ class GradientAverager(DecentralizedAverager):
         weight: Optional[float] = None,
         reset_accumulators: bool = True,
         control: Optional[StepControl] = None,
+        timeout: Optional[float] = None,
         wait: bool = True,
         **kwargs,
     ):
@@ -161,12 +162,13 @@ class GradientAverager(DecentralizedAverager):
         :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
         :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
         :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
+        :param timeout: if specified, await for averaging round for at most this number of seconds (if wait=True)
         :param wait: if True, await for the step to finish (or fail), otherwise run all-reduce in background
         """
         if control is None:
-            control = self.schedule_step(**kwargs)
+            control = self.schedule_step(timeout=timeout, **kwargs)
         elif len(kwargs) > 0:
-            RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
+            raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
         assert not control.triggered, f"This {type(control)} instance was already used."
         self._load_accumulators_into_averager_()
         self._accumulators_used_in_step = True
@@ -175,9 +177,9 @@ class GradientAverager(DecentralizedAverager):
         control.weight = self.local_samples_accumulated if weight is None else weight
         if reset_accumulators:
             self.reset_accumulated_grads_()
-
         control.allow_allreduce()
-        return control.result() if wait else control
+
+        return control.result(timeout) if wait else control
 
     @torch.no_grad()
     def _load_accumulators_into_averager_(self):
@@ -209,11 +211,11 @@ class GradientAverager(DecentralizedAverager):
         self._new_averaged_grads = False
         with self.get_tensors() as averaged_grads:
             try:
-                assert len(averaged_grads) == len(self._parameters)
-                old_grads = [param.grad for param in self._parameters]
-                for param, new_grad in zip(self._parameters, averaged_grads):
+                assert len(averaged_grads) == len(self.parameters)
+                old_grads = [param.grad for param in self.parameters]
+                for param, new_grad in zip(self.parameters, averaged_grads):
                     param.grad = new_grad
                 yield
             finally:
-                for param, old_grad in zip(self._parameters, old_grads):
+                for param, old_grad in zip(self.parameters, old_grads):
                     param.grad = old_grad