notify peers if averaging round is cancelled while awaiting trigger

justheuristic, 3 years ago
commit 8840aaab8d

+ 10 - 5
hivemind/averaging/averager.py

@@ -422,10 +422,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         await step.wait_for_trigger()
                     return group_info
                 except asyncio.CancelledError:
-                    return asyncio.wait(
+                    await asyncio.wait({
                         self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED)
-                        for peer_id in group_info.peer_ids
-                    )
+                        for peer_id in group_info.peer_ids if peer_id != self.peer_id
+                    })
+                    raise
 
             while not step.done():
                 try:
@@ -490,8 +491,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 )
 
     async def _send_error_to_peer(self, peer_id: PeerID, group_id: GroupID, code: averaging_pb2.MessageCode):
-        error = averaging_pb2.AveragingData(group_id=group_id, code=code)
-        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+        try:
+            error = averaging_pb2.AveragingData(group_id=group_id, code=code)
+            stub = type(self).get_stub(self._p2p, peer_id, namespace=self.prefix)
+            await afirst(await stub.rpc_aggregate_part(as_aiter(error)))
+        except Exception as e:
+            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}.")
 
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""

+ 8 - 9
hivemind/optim/experimental/optimizer.py

@@ -155,7 +155,6 @@ class Optimizer(torch.optim.Optimizer):
         extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
-        preschedule_state_averaging: bool = False,
         performance_ema_alpha: float = 0.1,
         shutdown_timeout: float = 5,
         verbose: bool = False,
@@ -190,7 +189,6 @@ class Optimizer(torch.optim.Optimizer):
         self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
-        self.preschedule_state_averaging = preschedule_state_averaging
 
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.shutdown_timeout = shutdown_timeout
@@ -350,8 +348,7 @@ class Optimizer(torch.optim.Optimizer):
                     return loss  # local gradients were reset due to overflow, must start over
 
             self._maybe_schedule_gradient_averaging()
-            if self.preschedule_state_averaging:
-                self._maybe_schedule_state_averaging()
+            self._maybe_schedule_state_averaging()
 
         else:
             # use_local_updates=True: update parameters on every step independently of other peers
@@ -362,8 +359,7 @@ class Optimizer(torch.optim.Optimizer):
 
                 new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
                 self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
-                if self.preschedule_state_averaging:
-                    self._maybe_schedule_state_averaging()
+                self._maybe_schedule_state_averaging()
 
                 self.state_averager.step(
                     increment_epoch=False,
@@ -400,7 +396,9 @@ class Optimizer(torch.optim.Optimizer):
             next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
             swarm_not_empty = self.tracker.global_progress.num_peers > 1
             should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
-            should_average_state = swarm_not_empty and next_epoch % self.average_state_every == 0
+            should_average_state = (swarm_not_empty and
+                                    next_epoch % self.average_state_every == 0 and
+                                    not self.state_averager.averaging_in_progress)
 
             if should_average_state and self.scheduled_state is not None:
                 if self.scheduled_state.triggered or self.scheduled_state.done():
@@ -410,7 +408,6 @@ class Optimizer(torch.optim.Optimizer):
                         f"was already used elsewhere: {self.scheduled_state}",
                     )
                     self.scheduled_state = None
-
                 self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
 
             self.state_averager.step(
@@ -502,10 +499,11 @@ class Optimizer(torch.optim.Optimizer):
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
-        assert self.preschedule_state_averaging
         next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
         if next_epoch % self.average_state_every != 0:
             return  # averaging is not performed at this epoch
+        if self.state_averager.averaging_in_progress:
+            return  # previous run is still in progress
 
         estimated_time = self.tracker.estimated_next_update_time
         estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
@@ -599,6 +597,7 @@ class Optimizer(torch.optim.Optimizer):
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
         self._finish_background_averaging()
+        self.state_averager.step(wait_for_delayed_updates=True)
 
         with self.tracker.pause_updates():
             while True:
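Two behavioural changes in this file are worth spelling out: state averaging is now always pre-scheduled (the `preschedule_state_averaging` switch is gone), and a new round is neither scheduled nor started while the previous one is still running, thanks to the `state_averager.averaging_in_progress` check. A toy version of the updated decision with the same three clauses as `should_average_state` in the diff; the standalone function below is only an illustration, not hivemind code:

```python
def should_average_state(num_peers: int, next_epoch: int,
                         average_state_every: int, averaging_in_progress: bool) -> bool:
    # Mirrors the three clauses from the diff: a non-trivial swarm, an epoch that is
    # due for state averaging, and no averaging round already in flight.
    swarm_not_empty = num_peers > 1
    return (swarm_not_empty
            and next_epoch % average_state_every == 0
            and not averaging_in_progress)


# The third clause is the new one: it prevents piling a second state-averaging
# round on top of one that has not finished yet.
print(should_average_state(num_peers=4, next_epoch=10, average_state_every=5, averaging_in_progress=False))  # True
print(should_average_state(num_peers=4, next_epoch=10, average_state_every=5, averaging_in_progress=True))   # False
```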

+ 23 - 17
hivemind/optim/experimental/state_averager.py

@@ -123,6 +123,7 @@ class TrainingStateAverager(DecentralizedAverager):
         self.finished_optimizer_step = threading.Event()
         self.finished_averaging_round = threading.Event()
         self.lock_optimizer = threading.Lock()
+        self.lock_averaging = threading.Lock()
         self.pending_updates = set()
 
         super().__init__(
@@ -509,23 +510,24 @@ class TrainingStateAverager(DecentralizedAverager):
                 self.finished_optimizer_step.set()
 
             if averaging_round:
-                if not self.reuse_tensors:
-                    self._load_local_tensors_into_averager_()
-                if self.delta_rule_averaging:
-                    # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
-                    with torch.no_grad(), self.get_tensors() as averaged_tensors:
-                        self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
-
-                self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
-                try:
-                    averaging_control.allow_allreduce()
-                    gathered = averaging_control.result(timeout=timeout)
-                    logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
-                    gathered = {}
-
-                self.finished_averaging_round.set()
+                with self.lock_averaging:
+                    if not self.reuse_tensors:
+                        self._load_local_tensors_into_averager_()
+                    if self.delta_rule_averaging:
+                        # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
+                        with torch.no_grad(), self.get_tensors() as averaged_tensors:
+                            self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
+
+                    self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
+                    try:
+                        averaging_control.allow_allreduce()
+                        gathered = averaging_control.result(timeout=timeout)
+                        logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
+                    except BaseException as e:
+                        logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
+                        gathered = {}
+
+                    self.finished_averaging_round.set()
 
                 if self.sync_epoch_when_averaging:
                     old_epoch = self.local_epoch
@@ -589,6 +591,10 @@ class TrainingStateAverager(DecentralizedAverager):
                     delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
                     local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
 
+    @property
+    def averaging_in_progress(self) -> bool:
+        return self.lock_averaging.locked()
+
     def get_current_state(self):
         """
         Get current model/optimizer state and when requested by a newbie peer. executed in the host process.
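The new `lock_averaging` is held for the whole in-flight averaging round, and `averaging_in_progress` simply exposes `Lock.locked()` so the optimizer thread can cheaply decide to skip scheduling another round. A minimal sketch of that pattern with a hypothetical `BackgroundAverager` class (not hivemind's `TrainingStateAverager`):

```python
import threading
import time


class BackgroundAverager:
    def __init__(self) -> None:
        # Held for the entire averaging round, mirroring lock_averaging in the diff.
        self.lock_averaging = threading.Lock()

    @property
    def averaging_in_progress(self) -> bool:
        return self.lock_averaging.locked()

    def run_round(self, duration: float) -> None:
        with self.lock_averaging:
            time.sleep(duration)  # stands in for the all-reduce round


def maybe_schedule_round(averager: BackgroundAverager) -> bool:
    if averager.averaging_in_progress:
        return False  # previous run is still in progress, skip this epoch
    threading.Thread(target=averager.run_round, args=(0.5,), daemon=True).start()
    return True


averager = BackgroundAverager()
print(maybe_schedule_round(averager))  # True: a round is started in the background
time.sleep(0.1)
print(maybe_schedule_round(averager))  # False: the first round still holds the lock
time.sleep(0.6)
print(maybe_schedule_round(averager))  # True: the lock was released when the round ended
```

The `.locked()` check is advisory rather than atomic with the acquisition, which matches how the diff uses it: as a cheap hint to avoid queueing a redundant round, while the lock itself still serializes the rounds that do run.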