dht_swarms.py

import asyncio
import multiprocessing as mp
import random
import signal
import threading
from typing import Dict, List, Tuple

from multiaddr import Multiaddr

from hivemind.dht.node import DHTID, Endpoint, DHTNode


def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
    if asyncio.get_event_loop().is_running():
        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
        asyncio.set_event_loop(asyncio.new_event_loop())
    loop = asyncio.get_event_loop()

    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, ping_n_attempts=10))
    maddrs = loop.run_until_complete(node.get_visible_maddrs())

    # report this node's identity and visible addresses back to the parent process
    info_queue.put((node.node_id, node.endpoint, maddrs))

    async def shutdown():
        await node.shutdown()
        loop.stop()

    # keep serving DHT requests until the parent sends SIGTERM
    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
    loop.run_forever()
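

# Usage sketch (illustrative, not part of the original module): run_node is
# designed to be the target of a child process; the queue handshake below is
# how the parent learns the new node's identity once it is online.
#
#     queue = mp.Queue()
#     proc = mp.Process(target=run_node, args=([], queue), daemon=True)
#     proc.start()
#     node_id, endpoint, maddrs = queue.get()  # blocks until the node is up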


def launch_swarm_in_separate_processes(n_peers: int, n_sequential_peers: int) -> \
        Tuple[List[mp.Process], Dict[Endpoint, DHTID], List[List[Multiaddr]]]:
    assert n_sequential_peers < n_peers, \
        'Parameters imply that first n_sequential_peers of n_peers will be run sequentially'

    processes = []
    dht = {}
    swarm_maddrs = []

    info_queue = mp.Queue()
    info_lock = mp.RLock()

    # Launch the first few peers one at a time, waiting for each to come online,
    # so that every later peer always has a live node to bootstrap from.
    for _ in range(n_sequential_peers):
        initial_peers = random.choice(swarm_maddrs) if swarm_maddrs else []

        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
        proc.start()
        processes.append(proc)

        node_id, peer_endpoint, peer_maddrs = info_queue.get()
        dht[peer_endpoint] = node_id
        swarm_maddrs.append(peer_maddrs)

    def collect_info():
        while True:
            node_id, peer_endpoint, peer_maddrs = info_queue.get()

            with info_lock:
                dht[peer_endpoint] = node_id
                swarm_maddrs.append(peer_maddrs)

                if len(dht) == n_peers:
                    break

    # Collect the remaining peers' info in a background thread so that
    # the remaining processes can be launched concurrently.
    collect_thread = threading.Thread(target=collect_info)
    collect_thread.start()

    for _ in range(n_peers - n_sequential_peers):
        with info_lock:
            initial_peers = random.choice(swarm_maddrs)

        proc = mp.Process(target=run_node, args=(initial_peers, info_queue), daemon=True)
        proc.start()
        processes.append(proc)

    collect_thread.join()

    return processes, dht, swarm_maddrs
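

# Usage sketch (illustrative, not part of the original module). The swarm sizes
# below are arbitrary; proc.terminate() sends SIGTERM, which triggers
# run_node's shutdown handler in each child process:
#
#     processes, dht, swarm_maddrs = launch_swarm_in_separate_processes(
#         n_peers=16, n_sequential_peers=4)
#     assert len(dht) == 16  # one (endpoint -> node_id) entry per peer
#     for proc in processes:
#         proc.terminate()
#         proc.join()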


async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:
    # the first node is the hub: every other node bootstraps from its addresses
    nodes = [await DHTNode.create(**kwargs)]
    initial_peers = await nodes[0].get_visible_maddrs()
    nodes += await asyncio.gather(*[DHTNode.create(initial_peers=initial_peers, **kwargs)
                                    for _ in range(n_peers - 1)])
    return nodes
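

if __name__ == '__main__':
    # Hedged demo (an assumption for illustration, not part of the original
    # module): create a small star-shaped swarm in-process, then shut it down.
    async def _demo():
        nodes = await launch_star_shaped_swarm(n_peers=4)
        print(f'launched {len(nodes)} nodes, hub: {nodes[0].node_id}')
        for node in nodes:
            await node.shutdown()

    asyncio.run(_demo())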