
Make test_load_state_from_peers work

Aleksandr Borzunov · 4 years ago · commit 85785b9011
2 changed files with 7 additions and 14 deletions:
  1. hivemind/averaging/averager.py (+5, -11)
  2. hivemind/dht/__init__.py (+2, -3)

hivemind/averaging/averager.py (+5, -11)

@@ -119,9 +119,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
-        super().__init__()
+        mp.Process.__init__(self)
+        ServicerBase.__init__(self)
         self.dht = dht
-        self.p2p = dht.p2p
         self.client_mode = client_mode
         self._parent_pid = os.getpid()
         if self.client_mode:
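
The switch from super().__init__() to two explicit base-class calls matters because DecentralizedAverager inherits from both mp.Process and ServicerBase: mp.Process.__init__ does not chain via super(), so a single super().__init__() only reaches the first class in the MRO and leaves the second base uninitialized. A minimal, self-contained sketch of that mechanism (plain Python with a stand-in servicer class, not hivemind code):

    import multiprocessing as mp

    class ServicerStub:                        # stand-in for hivemind's ServicerBase
        def __init__(self):
            self.handlers_ready = True

    class Averager(mp.Process, ServicerStub):
        def __init__(self):
            mp.Process.__init__(self)          # set up the process machinery
            ServicerStub.__init__(self)        # set up the servicer state as well

    assert Averager().handlers_ready           # both bases were initialized explicitly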
@@ -191,9 +191,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def endpoint(self) -> Endpoint:
         return self.p2p.id
 
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.endpoint})"
-
     def run(self):
         """
         Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
@@ -211,6 +208,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
             async def _run():
+                self.p2p = await self.dht.replicate_p2p()
                 if not self.client_mode:
                     await self.add_p2p_handlers(self.p2p)
                 else:
@@ -473,7 +471,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
-                            subkey=self.endpoint,
+                            subkey=self.endpoint.to_base58(),
                             value=self.last_updated,
                             expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
                             return_future=True,
@@ -539,7 +537,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
             peer_priority = {
-                peer: float(info.value)
+                Endpoint.from_base58(peer): float(info.value)
                 for peer, info in peer_priority.items()
                 if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
             }
@@ -553,7 +551,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
                 if peer != self.endpoint:
                     logger.info(f"Downloading parameters from peer {peer}")
-                    stream = None
                     try:
                         stub = self.get_stub(self.p2p, peer)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
@@ -579,9 +576,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         return
                     except BaseException as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")
-                    finally:
-                        if stream is not None:
-                            await stream.code()
 
         finally:
             if not future.done():
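
The subkey and peer-priority changes hinge on Endpoint being an object rather than a plain string, so it is written to the DHT as base58 text and parsed back on read. A minimal sketch of that round trip (assuming Endpoint is hivemind's libp2p PeerID, as used in the averaging code; the raw bytes are purely illustrative):

    from hivemind.p2p import PeerID as Endpoint

    peer = Endpoint(b"\x12\x20" + bytes(32))        # dummy peer id built from raw multihash bytes
    subkey = peer.to_base58()                       # what gets stored under the DHT keys above
    assert Endpoint.from_base58(subkey) == peer     # what load_state_from_peers recovers before dialing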

hivemind/dht/__init__.py (+2, -3)

@@ -268,8 +268,7 @@ class DHT(mp.Process):
     async def _get_visible_maddrs(self, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
         return await node.get_visible_maddrs(latest=latest)
 
-    @property
-    def p2p(self) -> P2P:
+    async def replicate_p2p(self) -> P2P:
         """
         Get a replica of a P2P instance used in the DHT process internally.
         """
@@ -278,7 +277,7 @@ class DHT(mp.Process):
             return self._p2p_replica
 
         daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
-        self._p2p_replica = P2P.replicate(daemon_listen_maddr)
+        self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
         return self._p2p_replica
 
     async def _get_p2p_daemon_listen_maddr(self, node: DHTNode) -> Multiaddr:
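
Since replicate_p2p() is now a coroutine rather than a property, callers obtain their per-process P2P replica from inside their own event loop, as the averager's _run() does above. A hedged usage sketch (constructor arguments are illustrative; error handling omitted):

    import asyncio
    from hivemind import DHT

    async def main():
        dht = DHT(start=True)                  # spawns the DHT process and its p2p daemon
        p2p = await dht.replicate_p2p()        # lightweight client attached to the same daemon
        print(p2p.id)                          # the daemon's peer id; replicas of the same daemon share it
        dht.shutdown()

    asyncio.run(main())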