浏览代码

Speed up DHT swarm creation

Aleksandr Borzunov 4 年之前
父节点
当前提交
2ae476fe96
共有 3 个文件被更改,包括 13 次插入8 次删除
  1. 3 1
      hivemind/dht/__init__.py
  2. 3 2
      tests/test_averaging.py
  3. 7 5
      tests/test_utils/dht_swarms.py

+ 3 - 1
hivemind/dht/__init__.py

@@ -50,6 +50,7 @@ class DHT(mp.Process):
       The validators will be combined using the CompositeValidator class. It merges them when possible
       (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
+    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
     :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
     """
 
@@ -64,6 +65,7 @@ class DHT(mp.Process):
         max_workers: Optional[int] = None,
         record_validators: Iterable[RecordValidatorBase] = (),
         shutdown_timeout: float = 3,
+        await_ready: bool = True,
         **kwargs,
     ):
         self._parent_pid = os.getpid()
@@ -91,7 +93,7 @@ class DHT(mp.Process):
         self._p2p_replica = None
 
         if start:
-            self.run_in_background(await_ready=True)
+            self.run_in_background(await_ready=await_ready)
 
     def run(self) -> None:
         """Serve DHT forever. This function will not return until DHT node is shut down"""

+ 3 - 2
tests/test_averaging.py

@@ -290,8 +290,9 @@ def test_allgather():
         futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))
 
     gathered_data = [future.result() for future in futures]
-    gathered_data_reprs = [repr(sorted({peer_id.to_base58(): data for peer_id, data in result.items()}))
-                           for result in gathered_data]
+    gathered_data_reprs = [
+        repr(sorted({peer_id.to_base58(): data for peer_id, data in result.items()})) for result in gathered_data
+    ]
     assert len(set(gathered_data_reprs)) == 2
 
     reference_metadata = {

+ 7 - 5
tests/test_utils/dht_swarms.py

@@ -90,9 +90,11 @@ async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:
 
 
 def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]:
-    # TODO: Do it in parallel
+    dhts = [DHT(start=True, **kwargs)]
+    initial_peers = dhts[0].get_visible_maddrs()
 
-    instances = [DHT(start=True, **kwargs)]
-    initial_peers = instances[0].get_visible_maddrs()
-    instances.extend(DHT(initial_peers=initial_peers, start=True, **kwargs) for _ in range(n_peers - 1))
-    return instances
+    dhts.extend(DHT(initial_peers=initial_peers, start=True, await_ready=False, **kwargs) for _ in range(n_peers - 1))
+    for instance in dhts[1:]:
+        instance.ready.wait()
+
+    return dhts