|
@@ -15,7 +15,7 @@ from hivemind.dht import DHT, DHTID, DHTExpiration
|
|
|
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
|
|
|
from hivemind.proto import averaging_pb2
|
|
|
from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
|
|
|
-from hivemind.utils.asyncio import anext
|
|
|
+from hivemind.utils.asyncio import anext, cancel_and_wait
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
@@ -127,10 +127,9 @@ class Matchmaking:
|
|
|
raise
|
|
|
|
|
|
finally:
|
|
|
- if not request_leaders_task.done():
|
|
|
- request_leaders_task.cancel()
|
|
|
- if not self.assembled_group.done():
|
|
|
- self.assembled_group.cancel()
|
|
|
+ await cancel_and_wait(request_leaders_task)
|
|
|
+ self.assembled_group.cancel()
|
|
|
+
|
|
|
while len(self.current_followers) > 0:
|
|
|
await self.follower_was_discarded.wait()
|
|
|
self.follower_was_discarded.clear()
|
|
@@ -229,7 +228,7 @@ class Matchmaking:
|
|
|
logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
|
|
|
return None
|
|
|
except (P2PHandlerError, StopAsyncIteration) as e:
|
|
|
- logger.error(f"{self} - failed to request potential leader {leader}: {e}")
|
|
|
+ logger.exception(f"{self} - failed to request potential leader {leader}:")
|
|
|
return None
|
|
|
|
|
|
finally:
|
|
@@ -413,10 +412,9 @@ class PotentialLeaders:
|
|
|
try:
|
|
|
yield self
|
|
|
finally:
|
|
|
- if not update_queue_task.done():
|
|
|
- update_queue_task.cancel()
|
|
|
- if declare and not declare_averager_task.done():
|
|
|
- declare_averager_task.cancel()
|
|
|
+ await cancel_and_wait(update_queue_task)
|
|
|
+ if declare:
|
|
|
+ await cancel_and_wait(declare_averager_task)
|
|
|
|
|
|
for field in (
|
|
|
self.past_attempts,
|
|
@@ -477,37 +475,31 @@ class PotentialLeaders:
|
|
|
else:
|
|
|
return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
|
|
|
|
|
|
- async def _update_queue_periodically(self, key_manager: GroupKeyManager):
|
|
|
- try:
|
|
|
- DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
|
|
|
- while get_dht_time() < self.search_end_time:
|
|
|
- new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
|
|
|
- self.max_assured_time = max(
|
|
|
- self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
|
|
|
- )
|
|
|
+ async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
|
|
|
+ DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
|
|
|
+ while get_dht_time() < self.search_end_time:
|
|
|
+ new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
|
|
|
+ self.max_assured_time = max(
|
|
|
+ self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
|
|
|
+ )
|
|
|
|
|
|
- self.leader_queue.clear()
|
|
|
- for peer, peer_expiration_time in new_peers:
|
|
|
- if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
|
|
|
- continue
|
|
|
- self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
|
|
|
- self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
|
|
|
+ self.leader_queue.clear()
|
|
|
+ for peer, peer_expiration_time in new_peers:
|
|
|
+ if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
|
|
|
+ continue
|
|
|
+ self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
|
|
|
+ self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
|
|
|
|
|
|
- self.update_finished.set()
|
|
|
+ self.update_finished.set()
|
|
|
|
|
|
- await asyncio.wait(
|
|
|
- {self.running.wait(), self.update_triggered.wait()},
|
|
|
- return_when=asyncio.ALL_COMPLETED,
|
|
|
- timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
|
|
|
- )
|
|
|
- self.update_triggered.clear()
|
|
|
- except (concurrent.futures.CancelledError, asyncio.CancelledError):
|
|
|
- return # note: this is a compatibility layer for python3.7
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
|
|
|
- raise
|
|
|
+ await asyncio.wait(
|
|
|
+ {self.running.wait(), self.update_triggered.wait()},
|
|
|
+ return_when=asyncio.ALL_COMPLETED,
|
|
|
+ timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
|
|
|
+ )
|
|
|
+ self.update_triggered.clear()
|
|
|
|
|
|
- async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
|
|
|
+ async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
|
|
|
async with self.lock_declare:
|
|
|
try:
|
|
|
while True:
|
|
@@ -521,10 +513,6 @@ class PotentialLeaders:
|
|
|
await asyncio.sleep(self.declared_expiration_time - get_dht_time())
|
|
|
if self.running.is_set() and len(self.leader_queue) == 0:
|
|
|
await key_manager.update_key_on_not_enough_peers()
|
|
|
- except (concurrent.futures.CancelledError, asyncio.CancelledError):
|
|
|
- pass # note: this is a compatibility layer for python3.7
|
|
|
- except Exception as e: # note: we catch exceptions here because otherwise they are never printed
|
|
|
- logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
|
|
|
finally:
|
|
|
if self.declared_group_key is not None:
|
|
|
prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
|