
and now it's black

justheuristic 3 years ago
parent
commit
512cadc34e
3 changed files with 10 additions and 4 deletions
  1. hivemind/optim/__init__.py (+1, -0)
  2. hivemind/optim/collaborative.py (+8, -4)
  3. hivemind/optim/grad_scaler.py (+1, -0)

+ 1 - 0
hivemind/optim/__init__.py

@@ -2,3 +2,4 @@ from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
+from hivemind.optim.grad_scaler import HivemindGradScaler
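
With this re-export in place, the scaler becomes importable from `hivemind.optim` alongside the existing optimizers. A minimal sketch of the resulting import path:

```python
# Sketch: after this change, HivemindGradScaler can be imported next to the optimizers.
from hivemind.optim import CollaborativeOptimizer, HivemindGradScaler
```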

+ 8 - 4
hivemind/optim/collaborative.py

@@ -99,7 +99,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :note: If you are using CollaborativeOptimizer with lr_scheduler, it is recommended to pass this scheduler
       explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
     """
-    _step_supports_amp_scaling = True  # pytorch amp support
 
     def __init__(
         self,
@@ -150,6 +149,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
 
+        self._step_supports_amp_scaling = self.reuse_grad_buffers  # enable custom execution with torch GradScaler
+
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
         self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
@@ -218,7 +219,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
-        assert grad_scaler is None or isinstance(grad_scaler, HivemindGradScaler)
+        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
+            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
         if self.batch_size_per_step is None:
             if batch_size is None:
                 raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
@@ -373,7 +375,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return
         else:
             if self._grads is None:
-                self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
+                self._grads = [
+                    torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()
+                ]
             yield from self._grads
 
     @torch.no_grad()
@@ -381,7 +385,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         """add current gradients to grad accumulators (if any)"""
         if self.reuse_grad_buffers:
             # user is responsible for accumulating gradients in .grad buffers
-            assert batch_size == self.batch_size_per_step, "Custom batch size is not implemented for reuse_grad_buffers"
+            assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
         else:
             alpha = float(batch_size) / self.batch_size_per_step
             for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
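
For readers unfamiliar with the surrounding code: when `reuse_grad_buffers` is off, `accumulated_grads()` lazily allocates zero-filled copies of the gradient buffers on `accumulate_grads_on`, and each `step(batch_size=...)` adds the current gradients into them weighted by `batch_size / batch_size_per_step`. The standalone sketch below mirrors that behavior; the function name and arguments are hypothetical, not part of hivemind.

```python
import torch

def accumulate_sketch(grad_buffers, accumulators, batch_size, batch_size_per_step, device="cpu"):
    """Hypothetical standalone mirror of accumulated_grads() plus the accumulation loop above."""
    if accumulators is None:
        # lazily allocate zero-filled accumulators, as in the reformatted block above
        accumulators = [torch.zeros_like(grad, device=device) for grad in grad_buffers]
    alpha = float(batch_size) / batch_size_per_step  # fraction of a full step covered by this batch
    for grad_buf, grad_acc in zip(grad_buffers, accumulators):
        grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
    return accumulators
```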

+ 1 - 0
hivemind/optim/grad_scaler.py

@@ -13,6 +13,7 @@ logger = get_logger(__name__)
 
 class HivemindGradScaler(TorchGradScaler):
     """A thin wrapper over GradScaler that supports hivemind-style training with CollaborativeOptimizer and others"""
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._is_running_global_step = False
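
For context on how these pieces fit together, here is a hedged sketch of a mixed-precision training loop using the scaler; `model`, `dataloader`, and the optimizer construction are placeholders, and the loop illustrates intended usage rather than an official recipe.

```python
import torch
from hivemind.optim import CollaborativeOptimizer, HivemindGradScaler

# Placeholder: `opt` is assumed to be a CollaborativeOptimizer created with
# reuse_grad_buffers=True, which (per the change above) sets _step_supports_amp_scaling.
scaler = HivemindGradScaler()

for batch in dataloader:  # placeholder data source
    with torch.cuda.amp.autocast():
        loss = model(batch)  # placeholder model returning a scalar loss
    scaler.scale(loss).backward()
    scaler.step(opt)    # .step() checks that the scaler is hivemind-aware, as shown above
    scaler.update()
    # With reuse_grad_buffers=True the optimizer works directly with .grad buffers,
    # so whether and when to zero them depends on your configuration.
```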