test_start_server.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os
  2. import re
  3. from subprocess import PIPE, Popen
  4. from tempfile import TemporaryDirectory
  5. from hivemind.moe.server import background_server
  6. def test_background_server_identity_path():
  7. with TemporaryDirectory() as tempdir:
  8. id_path = os.path.join(tempdir, "id")
  9. with background_server(num_experts=1, identity_path=id_path) as server_info_1, background_server(
  10. num_experts=1, identity_path=id_path
  11. ) as server_info_2, background_server(num_experts=1, identity_path=None) as server_info_3:
  12. assert server_info_1.peer_id == server_info_2.peer_id
  13. assert server_info_1.peer_id != server_info_3.peer_id
  14. assert server_info_3.peer_id == server_info_3.peer_id
  15. def test_cli_run_server_identity_path():
  16. pattern = r"Running DHT node on \[(.+)\],"
  17. with TemporaryDirectory() as tempdir:
  18. id_path = os.path.join(tempdir, "id")
  19. server_1_proc = Popen(
  20. ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
  21. stderr=PIPE,
  22. text=True,
  23. encoding="utf-8",
  24. )
  25. # Skip line "Generating new identity (libp2p private key) in {path to file}"
  26. line = server_1_proc.stderr.readline()
  27. line = server_1_proc.stderr.readline()
  28. addrs_1 = set(re.search(pattern, line).group(1).split(", "))
  29. ids_1 = set(a.split("/")[-1] for a in addrs_1)
  30. assert len(ids_1) == 1
  31. server_2_proc = Popen(
  32. ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
  33. stderr=PIPE,
  34. text=True,
  35. encoding="utf-8",
  36. )
  37. line = server_2_proc.stderr.readline()
  38. addrs_2 = set(re.search(pattern, line).group(1).split(", "))
  39. ids_2 = set(a.split("/")[-1] for a in addrs_2)
  40. assert len(ids_2) == 1
  41. server_3_proc = Popen(
  42. ["hivemind-server", "--num_experts", "1"],
  43. stderr=PIPE,
  44. text=True,
  45. encoding="utf-8",
  46. )
  47. line = server_3_proc.stderr.readline()
  48. addrs_3 = set(re.search(pattern, line).group(1).split(", "))
  49. ids_3 = set(a.split("/")[-1] for a in addrs_3)
  50. assert len(ids_3) == 1
  51. assert ids_1 == ids_2
  52. assert ids_1 != ids_3
  53. assert ids_2 != ids_3
  54. assert addrs_1 != addrs_2
  55. assert addrs_1 != addrs_3
  56. assert addrs_2 != addrs_3
  57. server_1_proc.terminate()
  58. server_2_proc.terminate()
  59. server_3_proc.terminate()
  60. server_1_proc.wait()
  61. server_2_proc.wait()
  62. server_3_proc.wait()