
Speed up DHT swarm creation

Aleksandr Borzunov, 4 years ago
parent
commit 2ae476fe96

+ 3 - 1
hivemind/dht/__init__.py

@@ -50,6 +50,7 @@ class DHT(mp.Process):
       The validators will be combined using the CompositeValidator class. It merges them when possible
       (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
+    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
     :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
     """

@@ -64,6 +65,7 @@ class DHT(mp.Process):
         max_workers: Optional[int] = None,
         record_validators: Iterable[RecordValidatorBase] = (),
         shutdown_timeout: float = 3,
+        await_ready: bool = True,
         **kwargs,
     ):
         self._parent_pid = os.getpid()
@@ -91,7 +93,7 @@ class DHT(mp.Process):
         self._p2p_replica = None

         if start:
-            self.run_in_background(await_ready=True)
+            self.run_in_background(await_ready=await_ready)

     def run(self) -> None:
         """Serve DHT forever. This function will not return until DHT node is shut down"""

+ 3 - 2
tests/test_averaging.py

@@ -290,8 +290,9 @@ def test_allgather():
         futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))

     gathered_data = [future.result() for future in futures]
-    gathered_data_reprs = [repr(sorted({peer_id.to_base58(): data for peer_id, data in result.items()}))
-                           for result in gathered_data]
+    gathered_data_reprs = [
+        repr(sorted({peer_id.to_base58(): data for peer_id, data in result.items()})) for result in gathered_data
+    ]
     assert len(set(gathered_data_reprs)) == 2

     reference_metadata = {

+ 7 - 5
tests/test_utils/dht_swarms.py

@@ -90,9 +90,11 @@ async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:


 def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]:
-    # TODO: Do it in parallel
+    dhts = [DHT(start=True, **kwargs)]
+    initial_peers = dhts[0].get_visible_maddrs()

-    instances = [DHT(start=True, **kwargs)]
-    initial_peers = instances[0].get_visible_maddrs()
-    instances.extend(DHT(initial_peers=initial_peers, start=True, **kwargs) for _ in range(n_peers - 1))
-    return instances
+    dhts.extend(DHT(initial_peers=initial_peers, start=True, await_ready=False, **kwargs) for _ in range(n_peers - 1))
+    for instance in dhts[1:]:
+        instance.ready.wait()
+
+    return dhts
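
The rewritten helper starts the first peer as the bootstrap node, launches the remaining peers with await_ready=False so their processes come up concurrently, and only then blocks on each peer's ready event; total startup time is therefore bounded by the slowest peer rather than the sum of all startups. A hypothetical test using the helper might look like the sketch below (the import path and test name are assumptions, not part of this commit):

from test_utils.dht_swarms import launch_dht_instances  # import path assumed, with tests/ as the working directory

def test_swarm_startup():  # hypothetical test, not part of this commit
    # Peers 2..n are spawned concurrently and awaited only after all of them exist.
    dhts = launch_dht_instances(8)
    try:
        assert all(dht.is_alive() for dht in dhts)  # DHT subclasses mp.Process
    finally:
        for dht in dhts:
            dht.shutdown()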