Quellcode durchsuchen

TODO DEBUG CONSIDER ROLLING THIS BACK

justheuristic vor 3 Jahren
Ursprung
Commit
01608cc6d0
2 geänderte Dateien mit 6 neuen und 10 gelöschten Zeilen
  1. 3 7
      hivemind/averaging/averager.py
  2. 3 3
      hivemind/optim/experimental/state_averager.py

+ 3 - 7
hivemind/averaging/averager.py

@@ -423,13 +423,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         await step.wait_for_trigger()
                     return group_info
                 except asyncio.CancelledError:
-                    await asyncio.wait(
-                        {
-                            self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED)
-                            for peer_id in group_info.peer_ids
-                            if peer_id != self.peer_id
-                        }
-                    )
+                    for peer_id in group_info.peer_ids:
+                        if peer_id != self.peer_id:
+                            asyncio.ensure_future(self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED))
                     raise
 
             while not step.done():

+ 3 - 3
hivemind/optim/experimental/state_averager.py

@@ -541,9 +541,9 @@ class TrainingStateAverager(DecentralizedAverager):
             if not began_running:
                 logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
             logger.exception(e)
-            if averaging_control is not None and not averaging_control.done():
-                logger.error(f"Cancelled scheduled state averaging round")
-                averaging_control.cancel()
+            if averaging_control is not None and not averaging_control.triggered:
+                averaging_control.weight = 0.0
+                averaging_control.allow_allreduce()
             self.finished_optimizer_step.set()
             self.finished_averaging_round.set()