|
@@ -119,9 +119,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
|
|
|
assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
|
|
|
|
|
|
- super().__init__()
|
|
|
+ mp.Process.__init__(self)
|
|
|
+ ServicerBase.__init__(self)
|
|
|
self.dht = dht
|
|
|
- self.p2p = dht.p2p
|
|
|
self.client_mode = client_mode
|
|
|
self._parent_pid = os.getpid()
|
|
|
if self.client_mode:
|
|
@@ -191,9 +191,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
def endpoint(self) -> Endpoint:
|
|
|
return self.p2p.id
|
|
|
|
|
|
- def __repr__(self):
|
|
|
- return f"{self.__class__.__name__}({self.endpoint})"
|
|
|
-
|
|
|
def run(self):
|
|
|
"""
|
|
|
Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
|
|
@@ -211,6 +208,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
|
|
|
|
|
|
async def _run():
|
|
|
+ self.p2p = await self.dht.replicate_p2p()
|
|
|
if not self.client_mode:
|
|
|
await self.add_p2p_handlers(self.p2p)
|
|
|
else:
|
|
@@ -473,7 +471,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
asyncio.wait_for(
|
|
|
self.dht.store(
|
|
|
download_key,
|
|
|
- subkey=self.endpoint,
|
|
|
+ subkey=self.endpoint.to_base58(),
|
|
|
value=self.last_updated,
|
|
|
expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
|
|
|
return_future=True,
|
|
@@ -539,7 +537,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
key_manager = self._matchmaking.group_key_manager
|
|
|
peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
|
|
|
peer_priority = {
|
|
|
- peer: float(info.value)
|
|
|
+ Endpoint.from_base58(peer): float(info.value)
|
|
|
for peer, info in peer_priority.items()
|
|
|
if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
|
|
|
}
|
|
@@ -553,7 +551,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
|
|
|
if peer != self.endpoint:
|
|
|
logger.info(f"Downloading parameters from peer {peer}")
|
|
|
- stream = None
|
|
|
try:
|
|
|
stub = self.get_stub(self.p2p, peer)
|
|
|
stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
|
|
@@ -579,9 +576,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
return
|
|
|
except BaseException as e:
|
|
|
logger.exception(f"Failed to download state from {peer} - {repr(e)}")
|
|
|
- finally:
|
|
|
- if stream is not None:
|
|
|
- await stream.code()
|
|
|
|
|
|
finally:
|
|
|
if not future.done():
|