
aux peers

justheuristic 3 years ago
parent
commit
3ff6c59de7
2 changed files with 37 additions and 30 deletions
  1. +34 -28  hivemind/optim/experimental/optimizer.py
  2. +3 -2    hivemind/optim/experimental/progress_tracker.py
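
This commit replaces the unimplemented `step_aux` stub with a first-class `auxiliary` mode: aux peers help run matchmaking and averaging rounds without accumulating gradients or reporting batches. Below is a minimal sketch of how such a peer might be constructed; the bootstrap address, the model, and the `optimizer=` argument are assumptions not shown in this diff, while the `prefix`, `batch_size_per_step=None`, `auxiliary=True`, and `client_mode` constraints come from the new assertions in `__init__`.

import torch
from hivemind import DHT
from hivemind.optim.experimental.optimizer import Optimizer

# Placeholder bootstrap address and model; neither is part of this commit.
dht = DHT(initial_peers=["/ip4/203.0.113.7/tcp/31337/p2p/REMOTE_PEER_ID"], start=True)
model = torch.nn.Linear(512, 512)

aux_opt = Optimizer(
    dht=dht,
    prefix="my_experiment",  # must match the prefix used by the training peers
    optimizer=torch.optim.Adam(model.parameters()),  # assumed to be accepted as elsewhere in this experimental API
    target_batch_size=4096,
    batch_size_per_step=None,  # aux peers must not accumulate batches (asserted in __init__)
    auxiliary=True,            # the new flag added by this commit
    client_mode=False,         # client-mode peers cannot serve as auxiliaries (asserted in __init__)
    verbose=True,
)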

+ 34 - 28
hivemind/optim/experimental/optimizer.py

@@ -95,22 +95,28 @@ class Optimizer(torch.optim.Optimizer):
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 300.0,
-        average_state_every: int = 1,
         load_state_timeout: float = 600.0,
+        average_state_every: int = 1,
         reuse_grad_buffers: bool = False,
         delay_optimizer_step: bool = False,
         client_mode: bool = None,
+        auxiliary: bool = False,
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
+        shutdown_timeout: float = 5,
         verbose: bool = False,
     ):
-        self.dht, self.prefix = dht, prefix
-        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
+        assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
+
+        self.dht, self.prefix, self.client_mode, self.auxiliary = dht, prefix, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
         self.matchmaking_time, self.delay_optimizer_step = matchmaking_time, delay_optimizer_step
         self.average_state_every = average_state_every
+        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
+        self.shutdown_timeout = shutdown_timeout
 
-        self.client_mode = client_mode if client_mode is not None else self.dht.client_mode
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.scheduled_round: Optional[StepControl] = None
 
@@ -132,8 +138,10 @@ class Optimizer(torch.optim.Optimizer):
             dht=self.dht,
             prefix=f"{self.prefix}_state_averager",
             allreduce_timeout=self.averaging_timeout,
+            shutdown_timeout=self.shutdown_timeout,
             status_loglevel=self.status_loglevel,
             client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
             offload_optimizer=True,
             custom_gradients=True,
             start=True,
@@ -147,7 +155,9 @@ class Optimizer(torch.optim.Optimizer):
             prefix=f"{self.prefix}_grad_averager",
             parameters=self.state_averager.main_parameters,
             allreduce_timeout=self.averaging_timeout,
+            shutdown_timeout=self.shutdown_timeout,
             client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
             start=True,
             **kwargs,
         )
@@ -207,9 +217,11 @@ class Optimizer(torch.optim.Optimizer):
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
         if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
-            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler).")
-        if self.batch_size_per_step is None and batch_size is None:
+            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
+        if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
             raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
+        if self.auxiliary and (closure is not None or batch_size is not None or grad_scaler is not None):
+            raise ValueError("Auxiliary peers should not have batch size run closures, or use grad_scaler")
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
         loss = None
@@ -217,7 +229,7 @@ class Optimizer(torch.optim.Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        if self.should_load_state_from_peers():
+        if not self.auxiliary and self.should_load_state_from_peers():
             logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
             return loss
@@ -228,9 +240,10 @@ class Optimizer(torch.optim.Optimizer):
             self.grad_averager.reset_accumulated_grads_()
             return loss
 
-        self.grad_averager.accumulate_grads_(batch_size)
-        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
-        self.state_averager.step(apply_delayed_updates=True)
+        if not self.auxiliary:
+            self.grad_averager.accumulate_grads_(batch_size)
+            self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
+            self.state_averager.step(apply_delayed_updates=True)
 
         if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
             if self.scheduled_round is None or self.scheduled_round.triggered or self.scheduled_round.done():
@@ -280,11 +293,11 @@ class Optimizer(torch.optim.Optimizer):
 
                 self.state_averager.step(
                     increment_epoch=True,
-                    optimizer_step=True,
+                    optimizer_step=not self.auxiliary,
                     delay_optimizer_step=self.delay_optimizer_step,
-                    grad_scaler=grad_scaler,
                     averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
-                    delay_averaging=True,
+                    delay_averaging=not self.auxiliary,
+                    grad_scaler=grad_scaler,
                     averaging_opts=dict(
                         scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
                     )
@@ -292,25 +305,18 @@ class Optimizer(torch.optim.Optimizer):
                     else None,
                 )
 
-            self.grad_averager.reset_accumulated_grads_()
-            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+            if not self.auxiliary:
+                self.grad_averager.reset_accumulated_grads_()
+                self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
 
-            if self.should_load_state_from_peers(new_epoch=True):
-                logger.log(self.status_loglevel, "Peer ended up out of sync after averaging.")
-                self.load_state_from_peers()
-                return loss
+                if self.should_load_state_from_peers(new_epoch=True):
+                    logger.log(self.status_loglevel, "Peer ended up out of sync after averaging.")
+                    self.load_state_from_peers()
+                    return loss
 
             logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
         return loss
 
-    def step_aux(self, **kwargs):
-        """
-        Find and assist other peers in averaging without sending local gradients.
-
-        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
-        """
-        raise NotImplementedError("Auxiliary step for hivemind.Optimizer is not implemented yet.")
-
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
         if self.grad_averager.reuse_grad_buffers:
@@ -390,7 +396,7 @@ class Optimizer(torch.optim.Optimizer):
 
     def shutdown(self):
         logger.debug("Sending goodbye to peers...")
-        self.tracker.shutdown()
+        self.tracker.shutdown(self.shutdown_timeout)
         logger.debug("Shutting down averager...")
         self.state_averager.step(wait_for_delayed_update=True)
         self.state_averager.shutdown()
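
With `step_aux` removed, trainers and auxiliary peers now drive the same `step()` method, and the `auxiliary` flag decides what it does. A hedged illustration of the two call patterns implied by the checks above; `dataloader`, `model`, and `trainer_opt` (a regular, non-auxiliary Optimizer with `batch_size_per_step` set) are placeholders, and `aux_opt` is the auxiliary instance from the first sketch.

# Training peer: accumulates gradients toward target_batch_size and reports local progress.
for inputs, targets in dataloader:
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    trainer_opt.step(batch_size=len(inputs))  # may trigger an averaging round once the swarm collects enough samples
    trainer_opt.zero_grad()

# Auxiliary peer: no data and no gradients; it only assists matchmaking and averaging.
while True:
    aux_opt.step()  # closure, batch_size, and grad_scaler must all be None on aux peers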

+ 3 - 2
hivemind/optim/experimental/progress_tracker.py

@@ -307,15 +307,16 @@ class ProgressTracker(threading.Thread):
             next_fetch_time=current_time + time_to_next_fetch,
         )
 
-    def shutdown(self):
+    def shutdown(self, timeout: Optional[float] = None):
         """Permanently disable all tracking activity"""
         self.shutdown_triggered.set()
         self.should_report_progress.set()
         self.global_state_updated.set()
-        self.shutdown_complete.wait()
+        self.shutdown_complete.wait(timeout)
         self.dht.store(
             self.training_progress_key,
             subkey=self._local_public_key,
             value=None,
             expiration_time=get_dht_time() + self.metadata_expiration,
+            return_future=True
         )
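
The `ProgressTracker.shutdown` change makes a departing peer's goodbye bounded: `shutdown_complete.wait(timeout)` caps the wait for the tracker thread, and `return_future=True` turns the final `dht.store` (which clears this peer's progress entry) into a non-blocking call. A small usage sketch, reusing the hypothetical `aux_opt` from the first example:

import time

start = time.monotonic()
aux_opt.shutdown()  # internally calls self.tracker.shutdown(self.shutdown_timeout), 5 seconds by default
print(f"goodbye sent in {time.monotonic() - start:.1f} s")  # only the tracker portion is bounded by the timeout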