# dht_swarms.py — helpers for launching DHT node swarms in tests
  1. import asyncio
  2. import multiprocessing as mp
  3. import random
  4. import signal
  5. import threading
  6. from typing import Dict, List, Tuple
  7. from multiaddr import Multiaddr
  8. from hivemind.dht.node import DHTID, DHTNode
  9. from hivemind.p2p import PeerID
def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
    """
    Run a single DHTNode in the current process until it receives SIGTERM.

    Creates a fresh event loop, starts the node, reports its identity back to
    the parent process via ``info_queue``, then blocks in ``run_forever()``.

    :param initial_peers: multiaddrs of existing peers to bootstrap from (may be empty)
    :param info_queue: queue onto which a ``(node_id, peer_id, visible_maddrs)``
        tuple is put once the node is up
    """
    if asyncio.get_event_loop().is_running():
        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop

    # replace whatever loop was inherited with a fresh one owned by this process
    asyncio.set_event_loop(asyncio.new_event_loop())
    loop = asyncio.get_event_loop()

    # ping_n_attempts=10: tolerate peers that are still starting up during bootstrap
    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, ping_n_attempts=10))
    maddrs = loop.run_until_complete(node.get_visible_maddrs())

    # report identity to the parent BEFORE blocking in run_forever(), so the
    # parent's blocking info_queue.get() can proceed
    info_queue.put((node.node_id, node.peer_id, maddrs))

    async def shutdown():
        # stop the node first, then the loop, so shutdown completes cleanly
        await node.shutdown()
        loop.stop()

    # graceful teardown: SIGTERM schedules shutdown() on the running loop
    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))

    loop.run_forever()
def launch_swarm_in_separate_processes(
    n_peers: int, n_sequential_peers: int
) -> Tuple[List[mp.Process], Dict[PeerID, DHTID], List[List[Multiaddr]]]:
    """
    Launch ``n_peers`` DHT nodes, each in its own daemon process, in two phases.

    Phase 1 starts ``n_sequential_peers`` one at a time — each launch blocks
    until the new node reports its maddrs, so later peers always have a live
    peer to bootstrap from. Phase 2 starts the remaining peers concurrently
    while a collector thread drains their reports from the queue.

    :param n_peers: total number of peers to launch
    :param n_sequential_peers: how many peers to start sequentially first (must be < n_peers)
    :returns: (list of peer processes, dict mapping peer_id -> node_id,
        list of visible maddrs, one entry per peer)
    """
    assert (
        n_sequential_peers < n_peers
    ), "Parameters imply that first n_sequential_peers of n_peers will be run sequentially"

    processes = []
    dht = {}
    swarm_maddrs = []

    info_queue = mp.Queue()
    # guards dht/swarm_maddrs, which are shared with the collector thread below
    info_lock = mp.RLock()

    for _ in range(n_sequential_peers):
        # bootstrap from a random already-running peer; the very first peer starts alone
        initial_peers = random.choice(swarm_maddrs) if swarm_maddrs else []

        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
        proc.start()
        processes.append(proc)

        # block until this node reports its identity — this is what makes phase 1 sequential
        node_id, peer_endpoint, peer_maddrs = info_queue.get()
        dht[peer_endpoint] = node_id
        swarm_maddrs.append(peer_maddrs)

    def collect_info():
        # drain the queue until every peer has reported (dht keys are unique peer_ids)
        while True:
            node_id, peer_endpoint, peer_maddrs = info_queue.get()
            with info_lock:
                dht[peer_endpoint] = node_id
                swarm_maddrs.append(peer_maddrs)

                if len(dht) == n_peers:
                    break

    collect_thread = threading.Thread(target=collect_info)
    collect_thread.start()

    for _ in range(n_peers - n_sequential_peers):
        # lock while reading swarm_maddrs: the collector thread appends to it concurrently
        with info_lock:
            initial_peers = random.choice(swarm_maddrs)

        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
        proc.start()
        processes.append(proc)

    # wait until the collector has seen all n_peers reports
    collect_thread.join()

    return processes, dht, swarm_maddrs
  60. async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:
  61. nodes = [await DHTNode.create(**kwargs)]
  62. initial_peers = await nodes[0].get_visible_maddrs()
  63. nodes += await asyncio.gather(*[DHTNode.create(initial_peers=initial_peers, **kwargs) for _ in range(n_peers - 1)])
  64. return nodes