Browse Source

Do not crash Averager.step on GeneratorExit

justheuristic 3 years ago
parent
commit
f33aefaee1

+ 5 - 1
hivemind/averaging/averager.py

@@ -31,6 +31,7 @@ from hivemind.compression import (
 )
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
@@ -467,9 +468,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     MatchmakingException,
                     AssertionError,
                     StopAsyncIteration,
+                    GeneratorExit,
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
                     P2PHandlerError,
+                    DispatchFailure,
+                    ControlFailure,
                 ) as e:
                     if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
                         if not step.cancelled():
@@ -535,7 +539,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
                             # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
-                            self._state_updated.set()
+                        self._state_updated.set()
 
                     else:
                         async for _ in allreduce:  # trigger all-reduce by iterating

+ 8 - 2
hivemind/averaging/matchmaking.py

@@ -227,7 +227,10 @@ class Matchmaking:
                     if suggested_leader != self.peer_id:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
-                        await stream.aclose()
+                        try:
+                            await stream.aclose()
+                        except RuntimeError as e:
+                            logger.debug(e, exc_info=True)
                         return await self._request_join_group(suggested_leader)
                 logger.debug(f"{self} - leader disbanded group")
                 return None
@@ -245,7 +248,10 @@ class Matchmaking:
             self.was_accepted_to_group.clear()
             self.current_leader = None
             if stream is not None:
-                await stream.aclose()
+                try:
+                    await stream.aclose()
+                except RuntimeError as e:
+                    logger.debug(e, exc_info=True)
 
     def get_request_expiration_time(self) -> float:
         """Returns the averager's current expiration time, which is used to send join requests to leaders"""

+ 1 - 1
hivemind/optim/experimental/optimizer.py

@@ -468,7 +468,7 @@ class Optimizer(torch.optim.Optimizer):
                 averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )
 
-            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
+            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.triggered:
                 self.scheduled_state.cancel()
             self.scheduled_state = None