1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- import multiprocessing as mp
- import os
- import threading
- from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout
- from typing import Dict
- from warnings import warn
- from .connection_handler import handle_connection
- from .network_handler import NetworkHandlerThread
- from ..network import TesseractNetwork
- from ..runtime import TesseractRuntime, ExpertBackend
class TesseractServer(threading.Thread):
    """
    Serves a collection of experts: runs a TesseractRuntime, a pool of connection-handler processes that share
    one listening TCP socket, and, when a network is given, a NetworkHandlerThread for it.

    :param network: TesseractNetwork to use; a falsy value skips starting the network and its handler thread
    :param expert_backends: ExpertBackend instances keyed by str; passed to the runtime, the handlers and
        the network handler thread
    :param addr: address handed to NetworkHandlerThread (the listening socket itself binds all interfaces)
    :param port: TCP port of the listening socket, also handed to NetworkHandlerThread
    :param conn_handler_processes: how many socket_loop processes accept connections on the shared socket
    :param update_period: handed to NetworkHandlerThread and used as the listening socket timeout (seconds)
    :param start: if True, start this thread at the end of __init__
    :param kwargs: forwarded to TesseractRuntime
    """

    def __init__(self, network: TesseractNetwork, expert_backends: Dict[str, ExpertBackend], addr: str = '127.0.0.1',
                 port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False,
                 **kwargs):
        super().__init__()
        self.network, self.experts, self.update_period = network, expert_backends, update_period
        self.addr, self.port = addr, port
        self._listening_socket = None  # assigned by _create_connection_handlers, closed by shutdown
        self.conn_handlers = self._create_connection_handlers(conn_handler_processes)
        self.runtime = TesseractRuntime(self.experts, **kwargs)
        if start:
            self.start()

    def run(self):
        """
        Start the network and its handler thread (if a network was given) and every connection-handler process
        that is not alive yet, then run the runtime in this thread. Once runtime.run() returns, wait for the
        handler processes and the network handler thread to finish.
        """
        network_thread = None
        if self.network:
            if not self.network.is_alive():
                self.network.start()
            network_thread = NetworkHandlerThread(experts=self.experts, network=self.network,
                                                  addr=self.addr, port=self.port, update_period=self.update_period)
            network_thread.start()

        for process in self.conn_handlers:
            if not process.is_alive():
                process.start()

        self.runtime.run()  # blocks this thread until the runtime stops

        for process in self.conn_handlers:
            process.join()
        if network_thread is not None:
            network_thread.join()

    @property
    def ready(self):
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

    def _create_connection_handlers(self, num_handlers):
        """
        Open the shared listening socket and build (but do not start) num_handlers socket_loop processes on it.
        The socket is kept in self._listening_socket so that shutdown() can close it.

        :returns: list of unstarted mp.Process objects
        """
        sock = socket(AF_INET, SOCK_STREAM)
        try:
            sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
            # NOTE(review): binds all interfaces; self.addr is only passed to NetworkHandlerThread -- confirm intended
            sock.bind(('', self.port))
            sock.listen()
            # accept() in socket_loop raises socket.timeout after this many seconds without a client
            sock.settimeout(self.update_period)
        except BaseException:
            sock.close()  # don't leak the descriptor when bind/listen fails (e.g. port already in use)
            raise
        self._listening_socket = sock
        return [mp.Process(target=socket_loop, name=f"socket_loop-{i}", args=(sock, self.experts))
                for i in range(num_handlers)]

    def shutdown(self):
        """
        Gracefully terminate a tesseract server, process-safe.

        Shuts the runtime down, terminates every connection-handler process and closes this process's copy of
        the listening socket so the port is released. Shutting down the network is not implemented yet.
        """
        self.runtime.shutdown()
        for process in self.conn_handlers:
            process.terminate()
        if self._listening_socket is not None:
            # handler processes hold their own duplicates of this socket; unless ours is closed too,
            # the port stays bound for as long as this server object (and its Process args) is alive
            self._listening_socket.close()
        warn("TODO shutdown network")
def socket_loop(sock, experts):
    """
    Connection-handler process body: repeatedly accept a client on `sock` and pass the
    (connection, address) pair from accept() to handle_connection along with `experts`.

    Socket timeouts, BrokenPipeError, ConnectionResetError and NotImplementedError are ignored and the loop
    keeps going; a KeyboardInterrupt prints a message and ends the loop. Any other exception propagates.
    """
    print(f'Spawned connection handler pid={os.getpid()}')
    ignorable_errors = (timeout, BrokenPipeError, ConnectionResetError, NotImplementedError)
    while True:
        try:
            handle_connection(sock.accept(), experts)
        except KeyboardInterrupt as interrupt:
            print(f'Socket loop has caught {type(interrupt)}, exiting')
            return
        except ignorable_errors:
            # e.g. accept() hit the socket timeout or the client dropped mid-request: wait for the next client
            continue
|