
await a running all-reduce instead of cancelling it;
hide network errors that occur while reporting other network errors.

justheuristic 3 years ago
parent
commit
aa31300425
2 changed files with 25 additions and 13 deletions
  1. +5 -2   hivemind/averaging/allreduce.py
  2. +20 -11  hivemind/optim/experimental/optimizer.py

+ 5 - 2
hivemind/averaging/allreduce.py

@@ -275,8 +275,11 @@ class AllReduceRunner(ServicerBase):
             yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
 
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
-        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+        try:
+            error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
+            await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+        except Exception as e:
+            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}.")
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

+ 20 - 11
hivemind/optim/experimental/optimizer.py

@@ -468,7 +468,7 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )
 
-            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
+            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.triggered:
                 self.scheduled_state.cancel()
             self.scheduled_state = None
 
@@ -513,7 +513,7 @@ class Optimizer(torch.optim.Optimizer):
         if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
             logger.log(self.status_loglevel, f"Cancelled pre-scheduled gradient averaging round")
             self.scheduled_grads.cancel()
-            self.scheduled_grads = None
+        self.scheduled_grads = None
         return began_averaging_gradients
 
     def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
@@ -646,7 +646,16 @@ class Optimizer(torch.optim.Optimizer):
 
         If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
         """
-        self._finish_background_averaging()
+        # note: we tag along for the next all-reduce because the run may have already started and cancelling it
+        # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self.scheduled_grads.weight = 0
+            self.scheduled_grads.allow_allreduce()
+            try:
+                self.scheduled_grads.result(self.averaging_timeout)
+            except BaseException as e:
+                logger.exception(e)
+
         self.state_averager.step(wait_for_delayed_updates=True)
 
         with self.tracker.pause_updates():
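
A minimal sketch of the await-instead-of-cancel pattern from the hunk above, assuming a scheduled-round object exposing the same done() / weight / allow_allreduce() / result() interface as the one used in the diff (the helper name join_pending_round is hypothetical):

import logging

logger = logging.getLogger(__name__)

def join_pending_round(scheduled, timeout: float):
    """Participate in a possibly already-running averaging round with zero weight
    instead of cancelling it, so the rest of the group does not restart matchmaking."""
    if scheduled is not None and not scheduled.done():
        scheduled.weight = 0           # contribute nothing to the average
        scheduled.allow_allreduce()    # let the round proceed
        try:
            scheduled.result(timeout)  # block until the round finishes or times out
        except BaseException as e:
            logger.exception(e)        # a failed round should not abort loading state
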
@@ -675,13 +684,6 @@ class Optimizer(torch.optim.Optimizer):
                     self.grad_averager.state_sharing_priority = self.local_epoch
 
     def _finish_background_averaging(self):
-        for scheduled_round in self.scheduled_grads, self.scheduled_state:
-            if scheduled_round is not None:
-                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
-                    scheduled_round.cancel()
-                if not scheduled_round.triggered:
-                    scheduled_round.weight = 0
-                    scheduled_round.allow_allreduce()
         self.scheduled_grads = self.scheduled_state = None
 
     def state_dict(self) -> dict:
@@ -727,7 +729,14 @@ class Optimizer(torch.optim.Optimizer):
         logger.log(self.status_loglevel, "Sending goodbye to peers...")
         self.tracker.shutdown(self.shutdown_timeout)
         self.state_averager.step(wait_for_delayed_updates=True)
-        self._finish_background_averaging()
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                if not scheduled_round.triggered:
+                    scheduled_round.weight = 0
+                    scheduled_round.allow_allreduce()
+
         logger.log(self.status_loglevel, "Shutting down averagers...")
         self.state_averager.shutdown()
         if self.use_gradient_averaging:
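
On shutdown, the hunk above treats the two possible states of a pre-scheduled round differently: a round still looking for a group is cancelled outright, while a round that already joined a group but was never triggered is released with zero weight so its peers are not left waiting. A hedged standalone sketch (the helper name and the AveragingStage import path are assumptions):

import logging
from hivemind.averaging.control import AveragingStage  # assumed import path

logger = logging.getLogger(__name__)

def release_scheduled_rounds(*rounds):
    """Illustrative helper: wind down pre-scheduled averaging rounds on shutdown."""
    for scheduled_round in rounds:
        if scheduled_round is None:
            continue
        if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
            scheduled_round.cancel()           # no group formed yet, safe to abort
        if not scheduled_round.triggered:
            scheduled_round.weight = 0         # contribute nothing to the average
            scheduled_round.allow_allreduce()  # unblock peers waiting for our trigger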