connection_handler.py 1.1 KB

1234567891011121314151617181920212223242526272829
  1. from socket import socket
  2. from typing import Tuple, Dict
  3. from tesseract.runtime.expert_backend import ExpertBackend
  4. from tesseract.utils import PytorchSerializer, Connection
  5. def handle_connection(connection_tuple: Tuple[socket, str], experts: Dict[str, ExpertBackend]):
  6. with Connection(*connection_tuple) as connection:
  7. try:
  8. header = connection.recv_header()
  9. payload = PytorchSerializer.loads(connection.recv_raw())
  10. if header == 'fwd_':
  11. uid, inputs = payload
  12. response = experts[uid].forward_pool.submit_task(*inputs).result()
  13. elif header == 'bwd_':
  14. uid, inputs_and_grad_outputs = payload
  15. response = experts[uid].backward_pool.submit_task(*inputs_and_grad_outputs).result()
  16. elif header == 'info':
  17. uid = payload
  18. response = experts[uid].get_info()
  19. else:
  20. raise NotImplementedError(f"Unknown header: {header}")
  21. connection.send_raw('rest', PytorchSerializer.dumps(response))
  22. except RuntimeError:
  23. # socket connection broken
  24. pass