run_server.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
import ast
import multiprocessing as mp
from contextlib import contextmanager

import torch

import tesseract
from .layers import name_to_block
  6. def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
  7. expert_prefix='expert.', expert_offset=0, max_batch_size=16384, device='cpu', no_optimizer=False,
  8. no_network=False, initial_peers=(), network_port=None, verbose=False, start=True, **kwargs
  9. ) -> tesseract.TesseractServer:
  10. """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
  11. if verbose and len(kwargs) != 0:
  12. print("Ignored kwargs:", kwargs)
  13. assert expert_cls in name_to_block
  14. num_handlers = num_handlers if num_handlers is not None else num_experts * 8
  15. device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
  16. # initialize network
  17. network = None
  18. if not no_network:
  19. initial_peers = eval(initial_peers)
  20. network = tesseract.TesseractNetwork(*initial_peers, port=network_port or tesseract.find_open_port(),
  21. start=True)
  22. if verbose:
  23. print("Parsed initial peers:", initial_peers)
  24. print(f"Running network node on port {network.port}")
  25. # initialize experts
  26. experts = {}
  27. for i in range(num_experts):
  28. expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
  29. opt = torch.optim.SGD(expert.parameters(), 0.0) if no_optimizer else torch.optim.Adam(expert.parameters())
  30. expert_uid = f'{expert_prefix}{i + expert_offset}'
  31. experts[expert_uid] = tesseract.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
  32. args_schema=(tesseract.BatchTensorProto(hidden_dim),),
  33. outputs_schema=tesseract.BatchTensorProto(hidden_dim),
  34. max_batch_size=max_batch_size,
  35. )
  36. # actually start server
  37. server = tesseract.TesseractServer(
  38. network, experts, addr=host, port=port or tesseract.find_open_port(),
  39. conn_handler_processes=num_handlers, device=device)
  40. if start:
  41. server.run_in_background(await_ready=True)
  42. if verbose:
  43. print(f"Server started at {server.addr}:{server.port}")
  44. print(f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}")
  45. return server
  46. @contextmanager
  47. def background_server(*args, verbose=True, **kwargs):
  48. """ Runs server in a background process and returns a reference to it. """
  49. recv_server, send_server = mp.Pipe(duplex=False)
  50. def server_runner():
  51. server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
  52. print('!!abouttosend')
  53. send_server.send(server)
  54. server.join()
  55. try:
  56. runner = mp.Process(target=server_runner)
  57. runner.start()
  58. print('!!waiting')
  59. server = recv_server.recv()
  60. print('!!received')
  61. yield server
  62. runner.join()
  63. finally:
  64. if verbose:
  65. print("Shutting down server...")
  66. server.shutdown()
  67. runner.terminate()
  68. if verbose:
  69. print("Server shut down successfully.")