# test_start_server.py (3.0 KB)
  1. import os
  2. import re
  3. from functools import partial
  4. from subprocess import PIPE, Popen
  5. from tempfile import TemporaryDirectory
  6. from hivemind.moe.server import background_server
  7. def test_background_server_identity_path():
  8. with TemporaryDirectory() as tempdir:
  9. id_path = os.path.join(tempdir, "id")
  10. server_runner = partial(background_server, num_experts=1, device="cpu", hidden_dim=1)
  11. with server_runner(identity_path=id_path) as server_info_1, server_runner(
  12. identity_path=id_path
  13. ) as server_info_2, server_runner(identity_path=None) as server_info_3:
  14. assert server_info_1.peer_id == server_info_2.peer_id
  15. assert server_info_1.peer_id != server_info_3.peer_id
  16. assert server_info_3.peer_id == server_info_3.peer_id
  17. def test_cli_run_server_identity_path():
  18. pattern = r"Running DHT node on \[(.+)\],"
  19. with TemporaryDirectory() as tempdir:
  20. id_path = os.path.join(tempdir, "id")
  21. server_1_proc = Popen(
  22. ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
  23. stderr=PIPE,
  24. text=True,
  25. encoding="utf-8",
  26. )
  27. line = server_1_proc.stderr.readline()
  28. assert "Generating new identity" in line
  29. line = server_1_proc.stderr.readline()
  30. addrs_pattern_result = re.search(pattern, line)
  31. assert addrs_pattern_result is not None, line
  32. addrs_1 = set(addrs_pattern_result.group(1).split(", "))
  33. ids_1 = set(a.split("/")[-1] for a in addrs_1)
  34. assert len(ids_1) == 1
  35. server_2_proc = Popen(
  36. ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
  37. stderr=PIPE,
  38. text=True,
  39. encoding="utf-8",
  40. )
  41. line = server_2_proc.stderr.readline()
  42. assert re.search(r"Checking that identity.+is not used by other peers", line) is not None
  43. line = server_2_proc.stderr.readline()
  44. addrs_pattern_result = re.search(pattern, line)
  45. assert addrs_pattern_result is not None, line
  46. addrs_2 = set(addrs_pattern_result.group(1).split(", "))
  47. ids_2 = set(a.split("/")[-1] for a in addrs_2)
  48. assert len(ids_2) == 1
  49. server_3_proc = Popen(
  50. ["hivemind-server", "--num_experts", "1"],
  51. stderr=PIPE,
  52. text=True,
  53. encoding="utf-8",
  54. )
  55. line = server_3_proc.stderr.readline()
  56. addrs_pattern_result = re.search(pattern, line)
  57. assert addrs_pattern_result is not None, line
  58. addrs_3 = set(addrs_pattern_result.group(1).split(", "))
  59. ids_3 = set(a.split("/")[-1] for a in addrs_3)
  60. assert len(ids_3) == 1
  61. assert ids_1 == ids_2
  62. assert ids_1 != ids_3
  63. assert ids_2 != ids_3
  64. assert addrs_1 != addrs_2
  65. assert addrs_1 != addrs_3
  66. assert addrs_2 != addrs_3
  67. server_1_proc.terminate()
  68. server_2_proc.terminate()
  69. server_3_proc.terminate()
  70. server_1_proc.wait()
  71. server_2_proc.wait()
  72. server_3_proc.wait()