dht_swarms.py

import asyncio
import multiprocessing as mp
import random
import signal
import threading
from typing import Dict, List, Tuple

from multiaddr import Multiaddr

from hivemind.dht import DHT
from hivemind.dht.node import DHTID, DHTNode
from hivemind.p2p import PeerID


def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
    if asyncio.get_event_loop().is_running():
        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
        asyncio.set_event_loop(asyncio.new_event_loop())
    loop = asyncio.get_event_loop()

    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, ping_n_attempts=10))
    maddrs = loop.run_until_complete(node.get_visible_maddrs())

    # report this node's identity and visible addresses back to the parent process
    info_queue.put((node.node_id, node.peer_id, maddrs))

    async def shutdown():
        await node.shutdown()
        loop.stop()

    # shut down gracefully when the parent process terminates this worker
    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
    loop.run_forever()


def launch_swarm_in_separate_processes(
    n_peers: int, n_sequential_peers: int
) -> Tuple[List[mp.Process], Dict[PeerID, DHTID], List[List[Multiaddr]]]:
    assert (
        n_sequential_peers < n_peers
    ), "Parameters imply that first n_sequential_peers of n_peers will be run sequentially"

    processes = []
    dht = {}
    swarm_maddrs = []

    info_queue = mp.Queue()
    info_lock = mp.RLock()

    # start the first peers one at a time so each new peer can bootstrap from an already running one
    for _ in range(n_sequential_peers):
        initial_peers = random.choice(swarm_maddrs) if swarm_maddrs else []

        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
        proc.start()
        processes.append(proc)

        node_id, peer_endpoint, peer_maddrs = info_queue.get()
        dht[peer_endpoint] = node_id
        swarm_maddrs.append(peer_maddrs)

    def collect_info():
        while True:
            node_id, peer_endpoint, peer_maddrs = info_queue.get()
            with info_lock:
                dht[peer_endpoint] = node_id
                swarm_maddrs.append(peer_maddrs)

                if len(dht) == n_peers:
                    break

    # collect info from the remaining peers in the background while they start concurrently
    collect_thread = threading.Thread(target=collect_info)
    collect_thread.start()

    for _ in range(n_peers - n_sequential_peers):
        with info_lock:
            initial_peers = random.choice(swarm_maddrs)

        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
        proc.start()
        processes.append(proc)

    collect_thread.join()

    return processes, dht, swarm_maddrs


async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:
    # create one root node, then connect every other node directly to it
    nodes = [await DHTNode.create(**kwargs)]
    initial_peers = await nodes[0].get_visible_maddrs()
    nodes += await asyncio.gather(*[DHTNode.create(initial_peers=initial_peers, **kwargs) for _ in range(n_peers - 1)])
    return nodes


def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]:
    # start the first DHT instance, then bootstrap the rest from its visible addresses
    dhts = [DHT(start=True, **kwargs)]
    initial_peers = dhts[0].get_visible_maddrs()

    dhts.extend(DHT(initial_peers=initial_peers, start=True, await_ready=False, **kwargs) for _ in range(n_peers - 1))
    for instance in dhts[1:]:
        instance.ready.wait()

    return dhts
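

# Illustrative usage sketch (not part of the original utilities): launch a small
# local swarm with launch_dht_instances, print each peer's visible multiaddrs,
# and shut everything down. The peer count below is arbitrary; only helpers and
# methods already used in this module are relied upon.
if __name__ == "__main__":
    dht_instances = launch_dht_instances(n_peers=4)
    for i, instance in enumerate(dht_instances):
        print(f"peer {i}: {instance.get_visible_maddrs()}")
    for instance in dht_instances:
        instance.shutdown()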