
Fix test_allreduce_once

Aleksandr Borzunov 4 years ago
commit 0384737569

+ 1 - 3
hivemind/averaging/averager.py

@@ -188,7 +188,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
     @property
     def endpoint(self) -> Endpoint:
-        return self._p2p.id
+        return self.dht.peer_id
 
     def run(self):
         """
@@ -257,8 +257,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        if not self.client_mode:
-            remaining_tasks.add(self._server.stop(timeout))
         await asyncio.gather(*remaining_tasks)
 
     def __del__(self):

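For context: after this change, DecentralizedAverager.endpoint is simply the peer id of the DHT it was given, and the averager no longer owns a separate P2P server to stop during shutdown. A minimal usage sketch of the new behaviour (the tensor and group settings below are placeholders, not the test's exact configuration):

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)
    averager = hivemind.averaging.DecentralizedAverager(
        [torch.zeros(3)], dht=dht, prefix="mygroup", target_group_size=2, start=True
    )
    # The endpoint property now just proxies the DHT's libp2p identity.
    assert averager.endpoint == dht.peer_id
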
+ 10 - 6
hivemind/averaging/matchmaking.py

@@ -237,7 +237,6 @@ class Matchmaking:
         self, request: averaging_pb2.JoinRequest, _: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
-        request_endpoint = PeerID.from_base58(request.endpoint)
         try:
             async with self.lock_request_join_group:
                 reason_to_reject = self._check_reasons_to_reject(request)
@@ -245,6 +244,7 @@ class Matchmaking:
                     yield reason_to_reject
                     return
 
+                request_endpoint = PeerID.from_base58(request.endpoint)
                 self.current_followers[request_endpoint] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
@@ -289,7 +289,7 @@ class Matchmaking:
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE,
                 group_id=group_info.group_id,
-                ordered_group_endpoints=group_info.endpoints,
+                ordered_group_endpoints=[item.to_base58() for item in group_info.endpoints],
                 gathered=group_info.gathered,
             )
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
@@ -309,14 +309,17 @@ class Matchmaking:
         if not self.is_looking_for_group or self.assembled_group.done():
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)
 
+        try:
+            request_endpoint = PeerID.from_base58(request.endpoint)
+        except (ValueError, TypeError):
+            request_endpoint = None
         if (
             request.ListFields() == 3
             and not isinstance(request.schema_hash, bytes)
             or len(request.schema_hash) == 0
             or not isinstance(request.expiration, DHTExpiration)
             or not isfinite(request.expiration)
-            or not isinstance(request.endpoint, Endpoint)
-            or len(request.endpoint) == 0
+            or request_endpoint is None
             or self.client_mode
         ):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
@@ -331,7 +334,7 @@ class Matchmaking:
             return averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
             )  # note: this suggested leader is currently ignored
-        elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
+        elif request_endpoint == self.endpoint or request_endpoint in self.current_followers:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
         elif len(self.current_followers) + 1 >= self.target_group_size:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
@@ -364,7 +367,8 @@ class Matchmaking:
         assert not self.assembled_group.done()
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"
 
-        group_id, ordered_group_endpoints = msg.group_id, msg.ordered_group_endpoints
+        group_id = msg.group_id
+        ordered_group_endpoints = [Endpoint.from_base58(item) for item in msg.ordered_group_endpoints]
         assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
         assert len(ordered_group_endpoints) == len(msg.gathered)
 

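The matchmaking changes settle on base58 strings as the wire format for peer identities in averaging_pb2 messages: the leader serializes group endpoints with to_base58() before sending BEGIN_ALLREDUCE, and followers parse them back with from_base58(). A minimal round-trip sketch, assuming a PeerID taken from a running DHT (variable names are illustrative):

    from hivemind.p2p import PeerID

    peer_id = dht.peer_id                    # PeerID of a running hivemind.DHT
    wire_form = peer_id.to_base58()          # plain str, safe for a protobuf string field
    assert PeerID.from_base58(wire_form) == peer_id
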
+ 16 - 6
hivemind/dht/__init__.py

@@ -24,7 +24,7 @@ from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, Type
 from multiaddr import Multiaddr
 
 from hivemind.dht.node import DHTNode
-from hivemind.p2p import P2P
+from hivemind.p2p import P2P, PeerID
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
@@ -86,7 +86,10 @@ class DHT(mp.Process):
         self.shutdown_timeout = shutdown_timeout
         self.ready = mp.Event()
         self.daemon = daemon
+
+        self._peer_id = None
         self._p2p_replica = None
+
         if start:
             self.run_in_background(await_ready=True)
 
@@ -256,6 +259,15 @@ class DHT(mp.Process):
     async def _add_validators(self, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
         node.protocol.record_validator.extend(record_validators)
 
+    @property
+    def peer_id(self) -> PeerID:
+        if self._peer_id is None:
+            self._peer_id = self.run_coroutine(DHT._get_peer_id)
+        return self._peer_id
+
+    async def _get_peer_id(self, node: DHTNode) -> PeerID:
+        return node.peer_id
+
     def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
         Get multiaddrs of the current DHT node that should be accessible by other peers.
@@ -273,11 +285,9 @@ class DHT(mp.Process):
         Get a replica of a P2P instance used in the DHT process internally.
         """
 
-        if self._p2p_replica is not None:
-            return self._p2p_replica
-
-        daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
-        self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
+        if self._p2p_replica is None:
+            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
+            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
         return self._p2p_replica
 
     async def _get_p2p_daemon_listen_maddr(self, node: DHTNode) -> Multiaddr:

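The new DHT.peer_id property lets callers in the parent process read the node's libp2p identity without replicating the P2P daemon: the first access runs _get_peer_id inside the DHT process via run_coroutine and caches the result. A short usage sketch (printing the id is only for illustration):

    import hivemind

    dht = hivemind.DHT(start=True)
    print(dht.peer_id.to_base58())  # fetched from the DHT process once, then cached
    dht.shutdown()
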
+ 13 - 9
tests/test_averaging.py

@@ -51,8 +51,6 @@ async def test_key_manager():
 
 
 def _test_allreduce_once(n_clients, n_aux):
-    dht = hivemind.DHT(start=True)
-
     n_peers = 4
     modes = (
         [AveragingMode.CLIENT] * n_clients
@@ -73,19 +71,23 @@ def _test_allreduce_once(n_clients, n_aux):
         for i in range(len(tensors1))
     ]
 
-    averagers = [
-        hivemind.averaging.DecentralizedAverager(
+    dht_root = hivemind.DHT(start=True)
+    initial_peers = dht_root.get_visible_maddrs()
+    averagers = []
+    dhts = []
+    for tensors, mode in zip(peer_tensors, modes):
+        dht_instance = hivemind.DHT(start=True, initial_peers=initial_peers)
+        dhts.append(dht_instance)
+        averagers.append(hivemind.averaging.DecentralizedAverager(
             tensors,
-            dht=dht,
+            dht=dht_instance,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
             client_mode=mode == AveragingMode.CLIENT,
             auxiliary=mode == AveragingMode.AUX,
             start=True,
-        )
-        for tensors, mode in zip(peer_tensors, modes)
-    ]
+        ))
 
     futures = []
     for averager in averagers:
@@ -103,7 +105,9 @@ def _test_allreduce_once(n_clients, n_aux):
 
     for averager in averagers:
         averager.shutdown()
-    dht.shutdown()
+    for instance in dhts:
+        instance.shutdown()
+    dht_root.shutdown()
 
 
 @pytest.mark.forked
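
The reworked test follows the usual bootstrap pattern: one root DHT exposes its visible multiaddrs, and every averager gets its own DHT instance that joins through them. Stripped of the averaging logic, the pattern looks roughly like this (the loop body is a simplification of the test above):

    import hivemind

    n_peers = 4
    dht_root = hivemind.DHT(start=True)
    initial_peers = dht_root.get_visible_maddrs()

    dhts = [hivemind.DHT(start=True, initial_peers=initial_peers) for _ in range(n_peers)]

    # ... attach one DecentralizedAverager to each DHT instance and run the step ...

    for dht_instance in dhts:
        dht_instance.shutdown()
    dht_root.shutdown()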