__init__.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import multiprocessing as mp
  2. import os
  3. import threading
  4. from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout
  5. from typing import Dict
  6. from warnings import warn
  7. from .connection_handler import handle_connection
  8. from .network_handler import NetworkHandlerThread
  9. from ..network import TesseractNetwork
  10. from ..runtime import TesseractRuntime, ExpertBackend
  11. class TesseractServer(threading.Thread):
  12. def __init__(self, network: TesseractNetwork, expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
  13. port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False,
  14. **kwargs):
  15. super().__init__()
  16. self.network, self.experts, self.update_period = network, expert_backends, update_period
  17. self.addr, self.port = addr, port
  18. self.conn_handlers = self._create_connection_handlers(conn_handler_processes)
  19. self.runtime = TesseractRuntime(self.experts, **kwargs)
  20. if start:
  21. self.start()
  22. def run(self):
  23. if self.network:
  24. if not self.network.is_alive():
  25. self.network.start()
  26. network_thread = NetworkHandlerThread(experts=self.experts, network=self.network,
  27. addr=self.addr, port=self.port, update_period=self.update_period)
  28. network_thread.start()
  29. for process in self.conn_handlers:
  30. if not process.is_alive():
  31. process.start()
  32. self.runtime.run()
  33. for process in self.conn_handlers:
  34. process.join()
  35. if self.network:
  36. network_thread.join()
  37. @property
  38. def ready(self):
  39. return self.runtime.ready # mp.Event that is true if self is ready to process batches
  40. def _create_connection_handlers(self, num_handlers):
  41. sock = socket(AF_INET, SOCK_STREAM)
  42. sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
  43. sock.bind(('', self.port))
  44. sock.listen()
  45. sock.settimeout(self.update_period)
  46. processes = [mp.Process(target=socket_loop, name=f"socket_loop-{i}", args=(sock, self.experts))
  47. for i in range(num_handlers)]
  48. return processes
  49. def shutdown(self):
  50. """ Gracefully terminate a tesseract server, process-safe """
  51. self.runtime.shutdown()
  52. for process in self.conn_handlers:
  53. process.terminate()
  54. warn("TODO shutdown network")
  55. def socket_loop(sock, experts):
  56. """ catch connections, send tasks to processing, respond with results """
  57. print(f'Spawned connection handler pid={os.getpid()}')
  58. while True:
  59. try:
  60. handle_connection(sock.accept(), experts)
  61. except KeyboardInterrupt as e:
  62. print(f'Socket loop has caught {type(e)}, exiting')
  63. break
  64. except (timeout, BrokenPipeError, ConnectionResetError, NotImplementedError):
  65. continue