@@ -31,6 +31,7 @@ from hivemind.compression import (
 )
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
@@ -467,9 +468,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     MatchmakingException,
                     AssertionError,
                     StopAsyncIteration,
+                    GeneratorExit,
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
                     P2PHandlerError,
+                    DispatchFailure,
+                    ControlFailure,
                 ) as e:
                     if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
                         if not step.cancelled():
@@ -535,7 +539,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
                         # all-reduce is performed asynchronously while iterating
                         tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
+                    self._state_updated.set()

                 else:
                     async for _ in allreduce:  # trigger all-reduce by iterating
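
Not part of the patch: the sketch below only illustrates the retry pattern that the widened `except` tuple above feeds into, where transient p2p-daemon errors (`DispatchFailure`, `ControlFailure`) and interrupted streams (`GeneratorExit`) cause the averaging attempt to be repeated until the step's deadline, while other errors still propagate. The names `RETRYABLE_EXCEPTIONS` and `step_with_retries` are hypothetical stand-ins, not hivemind API, and the tuple lists only a subset of the exceptions in the real code.

```python
import asyncio

from hivemind.p2p import P2PHandlerError
from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure

# Hypothetical stand-in for the exception tuple patched above: these errors are
# treated as transient, so a failed averaging attempt is retried instead of
# failing the whole step.
RETRYABLE_EXCEPTIONS = (
    StopAsyncIteration,
    GeneratorExit,
    asyncio.CancelledError,
    asyncio.InvalidStateError,
    P2PHandlerError,
    DispatchFailure,
    ControlFailure,
)


async def step_with_retries(attempt_averaging_once, deadline: float, allow_retries: bool = True):
    """Hypothetical driver: repeat one averaging attempt on transient failures until the deadline."""
    loop = asyncio.get_running_loop()
    while True:
        try:
            return await attempt_averaging_once()
        except RETRYABLE_EXCEPTIONS:
            if not allow_retries or loop.time() >= deadline:
                raise  # out of retries: surface the transient error to the caller
            # otherwise loop around and start a fresh attempt
```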