
notify peers if averaging round is cancelled while awaiting trigger

justheuristic 3 years ago
parent
commit 8840aaab8d

+ 10 - 5
hivemind/averaging/averager.py

@@ -422,10 +422,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         await step.wait_for_trigger()
                     return group_info
                 except asyncio.CancelledError:
-                    return asyncio.wait(
+                    await asyncio.wait({
                         self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED)
-                        for peer_id in group_info.peer_ids
-                    )
+                        for peer_id in group_info.peer_ids if peer_id != self.peer_id
+                    })
+                    raise
 
             while not step.done():
                 try:
@@ -490,8 +491,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 )
 
     async def _send_error_to_peer(self, peer_id: PeerID, group_id: GroupID, code: averaging_pb2.MessageCode):
-        error = averaging_pb2.AveragingData(group_id=group_id, code=code)
-        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+        try:
+            error = averaging_pb2.AveragingData(group_id=group_id, code=code)
+            stub = type(self).get_stub(self._p2p, peer_id, namespace=self.prefix)
+            await afirst(await stub.rpc_aggregate_part(as_aiter(error)))
+        except Exception as e:
+            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}.")
 
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""

+ 8 - 9
hivemind/optim/experimental/optimizer.py

@@ -155,7 +155,6 @@ class Optimizer(torch.optim.Optimizer):
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
-        preschedule_state_averaging: bool = False,
         performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
         verbose: bool = False,
@@ -190,7 +189,6 @@ class Optimizer(torch.optim.Optimizer):
         self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
-        self.preschedule_state_averaging = preschedule_state_averaging
 
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.shutdown_timeout = shutdown_timeout
@@ -350,8 +348,7 @@ class Optimizer(torch.optim.Optimizer):
                     return loss  # local gradients were reset due to overflow, must start over
 
             self._maybe_schedule_gradient_averaging()
-            if self.preschedule_state_averaging:
-                self._maybe_schedule_state_averaging()
+            self._maybe_schedule_state_averaging()
 
         else:
             # use_local_updates=True: update parameters on every step independently of other peers
@@ -362,8 +359,7 @@ class Optimizer(torch.optim.Optimizer):
 
                 new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
                 self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
-                if self.preschedule_state_averaging:
-                    self._maybe_schedule_state_averaging()
+                self._maybe_schedule_state_averaging()
 
                 self.state_averager.step(
                     increment_epoch=False,
@@ -400,7 +396,9 @@ class Optimizer(torch.optim.Optimizer):
             next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
             swarm_not_empty = self.tracker.global_progress.num_peers > 1
             should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
-            should_average_state = swarm_not_empty and next_epoch % self.average_state_every == 0
+            should_average_state = (swarm_not_empty and
+                                    next_epoch % self.average_state_every == 0 and
+                                    not self.state_averager.averaging_in_progress)
 
             if should_average_state and self.scheduled_state is not None:
                 if self.scheduled_state.triggered or self.scheduled_state.done():
@@ -410,7 +408,6 @@ class Optimizer(torch.optim.Optimizer):
                         f"was already used elsewhere: {self.scheduled_state}",
                     )
                     self.scheduled_state = None
-
                 self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
 
             self.state_averager.step(
@@ -502,10 +499,11 @@ class Optimizer(torch.optim.Optimizer):
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
-        assert self.preschedule_state_averaging
         next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
         if next_epoch % self.average_state_every != 0:
             return  # averaging is not performed at this epoch
+        if self.state_averager.averaging_in_progress:
+            return  # previous run is still in progress
 
         estimated_time = self.tracker.estimated_next_update_time
         estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
@@ -599,6 +597,7 @@ class Optimizer(torch.optim.Optimizer):
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
         self._finish_background_averaging()
+        self.state_averager.step(wait_for_delayed_updates=True)
 
         with self.tracker.pause_updates():
             while True:
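With the preschedule_state_averaging flag removed, prescheduling is always attempted, and the only remaining guards are the epoch check and the new averaging_in_progress check. A rough sketch of that combined condition under stand-in names (not the actual Optimizer API):

    def should_schedule_state_averaging(local_epoch, global_epoch,
                                         average_state_every, averaging_in_progress):
        # mirrors the epoch and in-progress guards added around state averaging
        next_epoch = max(local_epoch + 1, global_epoch)
        if next_epoch % average_state_every != 0:
            return False  # averaging is not performed at this epoch
        if averaging_in_progress:
            return False  # previous averaging round is still running in the background
        return True

    # with average_state_every=2, a round is scheduled only when the *next* epoch is even,
    # and never while a previous round still holds the averaging lock
    assert should_schedule_state_averaging(1, 0, 2, False)
    assert not should_schedule_state_averaging(2, 0, 2, False)
    assert not should_schedule_state_averaging(1, 0, 2, True)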

+ 23 - 17
hivemind/optim/experimental/state_averager.py

@@ -123,6 +123,7 @@ class TrainingStateAverager(DecentralizedAverager):
         self.finished_optimizer_step = threading.Event()
         self.finished_averaging_round = threading.Event()
         self.lock_optimizer = threading.Lock()
+        self.lock_averaging = threading.Lock()
         self.pending_updates = set()
 
         super().__init__(
@@ -509,23 +510,24 @@ class TrainingStateAverager(DecentralizedAverager):
                 self.finished_optimizer_step.set()
 
             if averaging_round:
-                if not self.reuse_tensors:
-                    self._load_local_tensors_into_averager_()
-                if self.delta_rule_averaging:
-                    # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
-                    with torch.no_grad(), self.get_tensors() as averaged_tensors:
-                        self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
-
-                self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
-                try:
-                    averaging_control.allow_allreduce()
-                    gathered = averaging_control.result(timeout=timeout)
-                    logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
-                    gathered = {}
-
-                self.finished_averaging_round.set()
+                with self.lock_averaging:
+                    if not self.reuse_tensors:
+                        self._load_local_tensors_into_averager_()
+                    if self.delta_rule_averaging:
+                        # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
+                        with torch.no_grad(), self.get_tensors() as averaged_tensors:
+                            self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
+
+                    self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
+                    try:
+                        averaging_control.allow_allreduce()
+                        gathered = averaging_control.result(timeout=timeout)
+                        logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
+                    except BaseException as e:
+                        logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
+                        gathered = {}
+
+                    self.finished_averaging_round.set()
 
                 if self.sync_epoch_when_averaging:
                     old_epoch = self.local_epoch
@@ -589,6 +591,10 @@ class TrainingStateAverager(DecentralizedAverager):
                     delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
                     local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
 
+    @property
+    def averaging_in_progress(self) -> bool:
+        return self.lock_averaging.locked()
+
     def get_current_state(self):
         """
        Get the current model/optimizer state when requested by a newbie peer; executed in the host process.
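The new averaging_in_progress property simply reports whether lock_averaging is currently held by the background averaging round, so the Optimizer can peek at it without blocking. A minimal sketch of this lock-as-progress-flag pattern, assuming nothing beyond the standard library:

    import threading
    import time

    class AveragerSketch:
        def __init__(self):
            self.lock_averaging = threading.Lock()

        @property
        def averaging_in_progress(self):
            # non-blocking peek: True while the background round holds the lock
            return self.lock_averaging.locked()

        def run_round(self):
            # the background thread holds lock_averaging for the whole round
            with self.lock_averaging:
                time.sleep(0.2)  # stand-in for the allreduce round

    averager = AveragerSketch()
    worker = threading.Thread(target=averager.run_round)
    worker.start()
    time.sleep(0.05)
    assert averager.averaging_in_progress   # round still running: do not schedule another
    worker.join()
    assert not averager.averaging_in_progress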