test_start_server.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import os
  2. import re
  3. from functools import partial
  4. from subprocess import PIPE, Popen
  5. from tempfile import TemporaryDirectory
  6. from hivemind.moe.server import background_server
  def test_background_server_identity_path():
      """Check that ``identity_path`` controls the peer ID of an in-process background server.

      Two servers started with the same ``identity_path`` must report the same peer ID, while a
      server started with ``identity_path=None`` must report a peer ID different from both.
      """
      with TemporaryDirectory() as tempdir:
          # Path inside a fresh temp dir, so it does not exist before the first server starts.
          # NOTE(review): creating/reusing the key file is done by background_server (not visible here).
          id_path = os.path.join(tempdir, "id")
          server_runner = partial(background_server, num_experts=1, device="cpu", hidden_dim=1)

          with server_runner(identity_path=id_path) as server_info_1, server_runner(
              identity_path=id_path
          ) as server_info_2, server_runner(identity_path=None) as server_info_3:
              # Same identity file -> same peer ID.
              assert server_info_1.peer_id == server_info_2.peer_id
              # No identity file -> a peer ID distinct from both servers that share id_path.
              assert server_info_1.peer_id != server_info_3.peer_id
              assert server_info_2.peer_id != server_info_3.peer_id
  def test_cli_run_server_identity_path():
      """Check that ``hivemind-server --identity_path`` yields a stable peer ID across processes.

      Starts three CLI servers: two sharing one identity file and one without ``--identity_path``.
      The two sharing the file must report the same peer ID but different listen addresses; the
      third must report a different peer ID. Every spawned process is terminated and its stderr
      pipe closed even if an assertion fails, so a failing run does not leak server processes.
      """
      pattern = r"Running DHT node on \[(.+)\],"

      def _start_server(*extra_args):
          # Launch a single-expert CLI server whose log output is read from its stderr pipe.
          return Popen(
              ["hivemind-server", "--num_experts", "1", *extra_args],
              stderr=PIPE,
              text=True,
              encoding="utf-8",
          )

      def _read_addrs_and_ids(proc):
          # Parse the next stderr line ("Running DHT node on [...]") into (multiaddrs, peer IDs).
          line = proc.stderr.readline()
          match = re.search(pattern, line)
          assert match is not None, line
          addrs = set(match.group(1).split(", "))
          # The peer ID is the last "/"-separated component of each advertised multiaddr.
          ids = {addr.split("/")[-1] for addr in addrs}
          assert len(ids) == 1, ids
          return addrs, ids

      with TemporaryDirectory() as tempdir:
          id_path = os.path.join(tempdir, "id")
          procs = []
          try:
              procs.append(_start_server("--identity_path", id_path))
              # The first server has no identity file yet, so it logs
              # "Generating new identity (libp2p private key) in {path to file}" before the DHT line.
              line = procs[0].stderr.readline()
              assert "Generating new identity" in line, line
              addrs_1, ids_1 = _read_addrs_and_ids(procs[0])

              # Started only after server 1 has logged its DHT addresses, i.e. after it generated
              # the identity, so this server is expected to reuse it (and logs no "Generating" line).
              procs.append(_start_server("--identity_path", id_path))
              addrs_2, ids_2 = _read_addrs_and_ids(procs[1])

              procs.append(_start_server())
              addrs_3, ids_3 = _read_addrs_and_ids(procs[2])

              assert ids_1 == ids_2
              assert ids_1 != ids_3
              assert ids_2 != ids_3
              assert addrs_1 != addrs_2
              assert addrs_1 != addrs_3
              assert addrs_2 != addrs_3
          finally:
              # Stop servers before the temp dir holding the identity file is removed.
              for proc in procs:
                  proc.terminate()
              for proc in procs:
                  proc.wait()
                  proc.stderr.close()