
Fix test_allreduce_once

Aleksandr Borzunov, 4 years ago
parent
commit 0384737569

+ 1 - 3
hivemind/averaging/averager.py

@@ -188,7 +188,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):

     @property
     def endpoint(self) -> Endpoint:
-        return self._p2p.id
+        return self.dht.peer_id

     def run(self):
         """
@@ -257,8 +257,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        if not self.client_mode:
-            remaining_tasks.add(self._server.stop(timeout))
         await asyncio.gather(*remaining_tasks)

     def __del__(self):

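With this change an averager no longer keeps its own identity: DecentralizedAverager.endpoint simply delegates to the peer id of the DHT it was constructed with, and shutdown no longer has a separate gRPC server to stop. A minimal usage sketch of the new behavior (the tensor shape and target_group_size below are illustrative, not taken from the commit):

import torch
import hivemind

dht = hivemind.DHT(start=True)
averager = hivemind.averaging.DecentralizedAverager(
    [torch.zeros(3)],   # any list of tensors to be averaged
    dht=dht,            # the averager's identity now comes from this DHT
    prefix="mygroup",
    target_group_size=2,
    start=True,
)
assert averager.endpoint == dht.peer_id  # endpoint now delegates to DHT.peer_id
averager.shutdown()
dht.shutdown()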
+ 10 - 6
hivemind/averaging/matchmaking.py

@@ -237,7 +237,6 @@ class Matchmaking:
         self, request: averaging_pb2.JoinRequest, _: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
-        request_endpoint = PeerID.from_base58(request.endpoint)
         try:
             async with self.lock_request_join_group:
                 reason_to_reject = self._check_reasons_to_reject(request)
@@ -245,6 +244,7 @@ class Matchmaking:
                     yield reason_to_reject
                     return

+                request_endpoint = PeerID.from_base58(request.endpoint)
                 self.current_followers[request_endpoint] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)

@@ -289,7 +289,7 @@ class Matchmaking:
             yield averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.BEGIN_ALLREDUCE,
                 group_id=group_info.group_id,
-                ordered_group_endpoints=group_info.endpoints,
+                ordered_group_endpoints=[item.to_base58() for item in group_info.endpoints],
                 gathered=group_info.gathered,
             )
         except (concurrent.futures.CancelledError, asyncio.CancelledError):
@@ -309,14 +309,17 @@ class Matchmaking:
         if not self.is_looking_for_group or self.assembled_group.done():
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)

+        try:
+            request_endpoint = PeerID.from_base58(request.endpoint)
+        except (ValueError, TypeError):
+            request_endpoint = None
         if (
             request.ListFields() == 3
             and not isinstance(request.schema_hash, bytes)
             or len(request.schema_hash) == 0
             or not isinstance(request.expiration, DHTExpiration)
             or not isfinite(request.expiration)
-            or not isinstance(request.endpoint, Endpoint)
-            or len(request.endpoint) == 0
+            or request_endpoint is None
             or self.client_mode
         ):
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
@@ -331,7 +334,7 @@ class Matchmaking:
             return averaging_pb2.MessageFromLeader(
                 code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
             )  # note: this suggested leader is currently ignored
-        elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
+        elif request_endpoint == self.endpoint or request_endpoint in self.current_followers:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
         elif len(self.current_followers) + 1 >= self.target_group_size:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
@@ -364,7 +367,8 @@ class Matchmaking:
         assert not self.assembled_group.done()
         assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"

-        group_id, ordered_group_endpoints = msg.group_id, msg.ordered_group_endpoints
+        group_id = msg.group_id
+        ordered_group_endpoints = [Endpoint.from_base58(item) for item in msg.ordered_group_endpoints]
         assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
         assert len(ordered_group_endpoints) == len(msg.gathered)

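The matchmaking changes settle on one serialization convention: peer ids travel through the averaging_pb2 messages as base58 strings (PeerID.to_base58 on send, PeerID.from_base58 on receipt), and a malformed endpoint is now answered with PROTOCOL_VIOLATION instead of raising inside rpc_join_group before the request is checked. A small sketch of that convention; the helper names encode_endpoints and decode_endpoint are illustrative, not part of the codebase:

from typing import List, Optional

from hivemind.p2p import PeerID


def encode_endpoints(peer_ids: List[PeerID]) -> List[str]:
    # what the leader does before filling ordered_group_endpoints in BEGIN_ALLREDUCE
    return [peer_id.to_base58() for peer_id in peer_ids]


def decode_endpoint(raw: str) -> Optional[PeerID]:
    # what _check_reasons_to_reject now does with request.endpoint: parse defensively
    # and treat failure as a protocol violation rather than crashing the handler
    try:
        return PeerID.from_base58(raw)
    except (ValueError, TypeError):
        return None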
+ 16 - 6
hivemind/dht/__init__.py

@@ -24,7 +24,7 @@ from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, Type
 from multiaddr import Multiaddr

 from hivemind.dht.node import DHTNode
-from hivemind.p2p import P2P
+from hivemind.p2p import P2P, PeerID
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
@@ -86,7 +86,10 @@ class DHT(mp.Process):
         self.shutdown_timeout = shutdown_timeout
         self.ready = mp.Event()
         self.daemon = daemon
+
+        self._peer_id = None
         self._p2p_replica = None
+
         if start:
             self.run_in_background(await_ready=True)

@@ -256,6 +259,15 @@ class DHT(mp.Process):
     async def _add_validators(self, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
         node.protocol.record_validator.extend(record_validators)

+    @property
+    def peer_id(self) -> PeerID:
+        if self._peer_id is None:
+            self._peer_id = self.run_coroutine(DHT._get_peer_id)
+        return self._peer_id
+
+    async def _get_peer_id(self, node: DHTNode) -> PeerID:
+        return node.peer_id
+
     def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
         """
         Get multiaddrs of the current DHT node that should be accessible by other peers.
@@ -273,11 +285,9 @@ class DHT(mp.Process):
         Get a replica of a P2P instance used in the DHT process internally.
         """

-        if self._p2p_replica is not None:
-            return self._p2p_replica
-
-        daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
-        self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
+        if self._p2p_replica is None:
+            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
+            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
         return self._p2p_replica

     async def _get_p2p_daemon_listen_maddr(self, node: DHTNode) -> Multiaddr:

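DHT now exposes the node's libp2p identity directly: the new peer_id property runs _get_peer_id inside the DHT process via run_coroutine on first access and caches the result, and replicate_p2p is rewritten around the same lazy-caching pattern. A usage sketch of the new property (printed values are, of course, machine-specific):

import hivemind

dht = hivemind.DHT(start=True)
print(dht.peer_id)               # PeerID of the DHTNode running inside the DHT process
print(dht.get_visible_maddrs())  # multiaddrs other peers can use to bootstrap from this node
assert dht.peer_id is dht.peer_id  # cached after the first run_coroutine round-trip
dht.shutdown()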
+ 13 - 9
tests/test_averaging.py

@@ -51,8 +51,6 @@ async def test_key_manager():


 def _test_allreduce_once(n_clients, n_aux):
-    dht = hivemind.DHT(start=True)
-
     n_peers = 4
     modes = (
         [AveragingMode.CLIENT] * n_clients
@@ -73,19 +71,23 @@ def _test_allreduce_once(n_clients, n_aux):
         for i in range(len(tensors1))
     ]

-    averagers = [
-        hivemind.averaging.DecentralizedAverager(
+    dht_root = hivemind.DHT(start=True)
+    initial_peers = dht_root.get_visible_maddrs()
+    averagers = []
+    dhts = []
+    for tensors, mode in zip(peer_tensors, modes):
+        dht_instance = hivemind.DHT(start=True, initial_peers=initial_peers)
+        dhts.append(dht_instance)
+        averagers.append(hivemind.averaging.DecentralizedAverager(
             tensors,
-            dht=dht,
+            dht=dht_instance,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
             client_mode=mode == AveragingMode.CLIENT,
             auxiliary=mode == AveragingMode.AUX,
             start=True,
-        )
-        for tensors, mode in zip(peer_tensors, modes)
-    ]
+        ))

     futures = []
     for averager in averagers:
@@ -103,7 +105,9 @@ def _test_allreduce_once(n_clients, n_aux):

     for averager in averagers:
         averager.shutdown()
-    dht.shutdown()
+    for instance in dhts:
+        instance.shutdown()
+    dht_root.shutdown()


 @pytest.mark.forked
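The rewritten test stops sharing a single DHT object between averagers: each averager gets its own DHT process, all bootstrapped from the visible multiaddrs of one root DHT. A condensed sketch of that bootstrap pattern outside the test harness; the peer count, tensor shapes, and the assumption that DecentralizedAverager.step(wait=False) returns a future are illustrative and not shown in full by this diff:

import torch
import hivemind
from hivemind.averaging import DecentralizedAverager

dht_root = hivemind.DHT(start=True)
initial_peers = dht_root.get_visible_maddrs()

# one DHT process per averager, all joining through the root node
dhts = [hivemind.DHT(start=True, initial_peers=initial_peers) for _ in range(4)]
averagers = [
    DecentralizedAverager(
        [torch.randn(16)],
        dht=dht,
        target_group_size=4,
        prefix="mygroup",
        start=True,
    )
    for dht in dhts
]

futures = [averager.step(wait=False) for averager in averagers]
for future in futures:
    future.result()  # block until this peer's allreduce round finishes

for averager in averagers:
    averager.shutdown()
for dht in dhts:
    dht.shutdown()
dht_root.shutdown()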