
and now it's black

justheuristic · 3 years ago
commit 512cadc34e
3 changed files with 10 additions and 4 deletions:
  1. hivemind/optim/__init__.py (+1, -0)
  2. hivemind/optim/collaborative.py (+8, -4)
  3. hivemind/optim/grad_scaler.py (+1, -0)

+ 1 - 0
hivemind/optim/__init__.py

@@ -2,3 +2,4 @@ from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
+from hivemind.optim.grad_scaler import HivemindGradScaler

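Because HivemindGradScaler is now re-exported from the hivemind.optim package, it can be imported next to CollaborativeOptimizer rather than through the hivemind.optim.grad_scaler module path. A minimal sketch of the new import; nothing beyond the import line itself is introduced by this commit:

    # HivemindGradScaler is now available directly from the hivemind.optim package
    from hivemind.optim import CollaborativeOptimizer, HivemindGradScaler

    # constructor arguments are forwarded unchanged to the underlying torch GradScaler
    scaler = HivemindGradScaler()
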
+ 8 - 4
hivemind/optim/collaborative.py

@@ -99,7 +99,6 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :note: If you are using CollaborativeOptimizer with lr_scheduler, it is recommended to pass this scheduler
       explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
     """
-    _step_supports_amp_scaling = True  # pytorch amp support
 
     def __init__(
         self,
@@ -150,6 +149,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.averager = self._make_averager(**kwargs)
 
+        self._step_supports_amp_scaling = self.reuse_grad_buffers  # enable custom execution with torch GradScaler
+
         self.training_progress_key = f"{self.prefix}_progress"
         self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
         self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
@@ -218,7 +219,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
-        assert grad_scaler is None or isinstance(grad_scaler, HivemindGradScaler)
+        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
+            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
         if self.batch_size_per_step is None:
             if batch_size is None:
                 raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
@@ -373,7 +375,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             return
         else:
             if self._grads is None:
-                self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
+                self._grads = [
+                    torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()
+                ]
             yield from self._grads
 
     @torch.no_grad()
@@ -381,7 +385,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         """add current gradients to grad accumulators (if any)"""
         """add current gradients to grad accumulators (if any)"""
         if self.reuse_grad_buffers:
         if self.reuse_grad_buffers:
             # user is responsible for accumulating gradients in .grad buffers
             # user is responsible for accumulating gradients in .grad buffers
-            assert batch_size == self.batch_size_per_step, "Custom batch size is not implemented for reuse_grad_buffers"
+            assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
         else:
         else:
             alpha = float(batch_size) / self.batch_size_per_step
             alpha = float(batch_size) / self.batch_size_per_step
             for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
             for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):

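Taken together, the collaborative.py changes advertise AMP support to PyTorch only when reuse_grad_buffers is enabled (_step_supports_amp_scaling is now set per instance from that flag), and .step() rejects anything other than a HivemindGradScaler with a ValueError instead of a bare assert. A rough usage sketch under those assumptions; the DHT, model, hyperparameters, and data stream below are illustrative placeholders rather than part of this commit, and the exact AMP bookkeeping may differ between hivemind versions:

    import torch
    import torch.nn.functional as F
    import hivemind
    from hivemind.optim import CollaborativeOptimizer, HivemindGradScaler

    dht = hivemind.DHT(start=True)              # placeholder single-node DHT
    model = torch.nn.Linear(512, 2).cuda()      # placeholder model

    opt = CollaborativeOptimizer(
        torch.optim.Adam(model.parameters()),
        dht=dht,
        prefix="demo_run",                      # placeholder experiment prefix
        target_batch_size=4096,                 # placeholder collaboration-wide batch size
        batch_size_per_step=32,                 # placeholder per-peer batch size
        reuse_grad_buffers=True,                # after this commit, _step_supports_amp_scaling equals this flag
        verbose=True,
    )
    scaler = HivemindGradScaler()               # a plain torch.cuda.amp.GradScaler would now raise ValueError in .step()

    def my_batches(num_steps=100, batch_size=32):
        """Placeholder stream of random fixed-size batches."""
        for _ in range(num_steps):
            yield torch.randn(batch_size, 512), torch.randint(0, 2, (batch_size,))

    for xs, ys in my_batches():
        with torch.cuda.amp.autocast():
            loss = F.cross_entropy(model(xs.cuda()), ys.cuda())
        scaler.scale(loss).backward()
        opt.step(grad_scaler=scaler)            # grad_scaler must be hivemind-aware, per the new type check
        scaler.update()                         # standard AMP bookkeeping after each step
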
+ 1 - 0
hivemind/optim/grad_scaler.py

@@ -13,6 +13,7 @@ logger = get_logger(__name__)
 
 class HivemindGradScaler(TorchGradScaler):
     """A thin wrapper over GradScaler that supports hivemind-style training with CollaborativeOptimizer and others"""
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._is_running_global_step = False