__init__.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import multiprocessing as mp
  2. import os
  3. import threading
  4. from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout
  5. from typing import Dict
  6. from .connection_handler import handle_connection
  7. from .network_handler import NetworkHandlerThread
  8. from ..network import TesseractNetwork
  9. from ..runtime import TesseractRuntime, ExpertBackend
  10. class TesseractServer(threading.Thread):
  11. """
  12. TesseractServer allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
  13. After creation, a server should be started: see TesseractServer.run or TesseractServer.run_in_background.
  14. A working server does 3 things:
  15. - processes incoming forward/backward requests via TesseractRuntime (created by the server)
  16. - publishes updates to expert status every :update_period: seconds
  17. - follows orders from TesseractController - if it exists
  18. :type network: TesseractNetwork or None. Server with network=None will NOT be visible from DHT,
  19. but it will still support accessing experts directly with RemoteExpert(uid=UID, host=IPADDR, port=PORT).
  20. :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
  21. :param addr: server's network address that determines how it can be accessed. Default is local connections only.
  22. :param port: port to which server listens for requests such as expert forward or backward pass.
  23. :param conn_handler_processes: maximum number of simultaneous requests. Please note that the default value of 1
  24. if too small for normal functioning, we recommend 4 handlers per expert backend.
  25. :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
  26. if network is None, this parameter is ignored.
  27. :param start: if True, the server will immediately start as a background thread and returns control after server
  28. is ready (see .ready below)
  29. """
  30. def __init__(self, network: TesseractNetwork, expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
  31. port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False,
  32. **kwargs):
  33. super().__init__()
  34. self.network, self.experts, self.update_period = network, expert_backends, update_period
  35. self.addr, self.port = addr, port
  36. self.conn_handlers = self._create_connection_handlers(conn_handler_processes)
  37. self.runtime = TesseractRuntime(self.experts, **kwargs)
  38. if start:
  39. self.run_in_background(await_ready=True)
  40. def run(self):
  41. """
  42. Starts TesseractServer in the current thread. Initializes network if necessary, starts connection handlers,
  43. runs TesseractRuntime (self.runtime) to process incoming requests.
  44. """
  45. if self.network:
  46. if not self.network.is_alive():
  47. self.network.run_in_background(await_ready=True)
  48. network_thread = NetworkHandlerThread(experts=self.experts, network=self.network,
  49. addr=self.addr, port=self.port, update_period=self.update_period)
  50. network_thread.start()
  51. for process in self.conn_handlers:
  52. if not process.is_alive():
  53. process.start()
  54. self.runtime.run()
  55. for process in self.conn_handlers:
  56. process.join()
  57. if self.network:
  58. network_thread.join()
  59. def run_in_background(self, await_ready=True, timeout=None):
  60. """
  61. Starts TesseractServer in a background thread. if await_ready, this method will wait until background server
  62. is ready to process incoming requests or for :timeout: seconds max.
  63. """
  64. self.start()
  65. if await_ready and not self.ready.wait(timeout=timeout):
  66. raise TimeoutError("TesseractServer didn't notify .ready in {timeout} seconds")
  67. @property
  68. def ready(self) -> mp.synchronize.Event:
  69. """
  70. An event (multiprocessing.Event) that is set when the server is ready to process requests.
  71. Example
  72. =======
  73. >>> server.start()
  74. >>> server.ready.wait(timeout=10)
  75. >>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
  76. """
  77. return self.runtime.ready # mp.Event that is true if self is ready to process batches
  78. def _create_connection_handlers(self, num_handlers):
  79. sock = socket(AF_INET, SOCK_STREAM)
  80. sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
  81. sock.bind(('', self.port))
  82. sock.listen()
  83. sock.settimeout(self.update_period)
  84. processes = [mp.Process(target=socket_loop, name=f"socket_loop-{i}", args=(sock, self.experts))
  85. for i in range(num_handlers)]
  86. return processes
  87. def shutdown(self):
  88. """
  89. Gracefully terminate a tesseract server, process-safe.
  90. Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
  91. If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
  92. """
  93. self.ready.clear()
  94. for process in self.conn_handlers:
  95. process.terminate()
  96. if self.network is not None:
  97. self.network.shutdown()
  98. self.runtime.shutdown()
  99. def socket_loop(sock, experts):
  100. """ catch connections, send tasks to processing, respond with results """
  101. print(f'Spawned connection handler pid={os.getpid()}')
  102. while True:
  103. try:
  104. handle_connection(sock.accept(), experts)
  105. except KeyboardInterrupt as e:
  106. print(f'Socket loop has caught {type(e)}, exiting')
  107. break
  108. except (timeout, BrokenPipeError, ConnectionResetError, NotImplementedError):
  109. continue