run_server.py

import multiprocessing as mp
from ast import literal_eval
from contextlib import contextmanager

import torch

import tesseract
from .layers import name_to_block


def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
                      expert_prefix='expert.', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
                      no_network=False, initial_peers=(), network_port=None, verbose=False, start=True, **kwargs
                      ) -> tesseract.TesseractServer:
    """ Creates num_experts dummy experts and a TesseractServer to serve them; if start is True, runs the server in the background and awaits readiness """
    if verbose and len(kwargs) != 0:
        print("Ignored kwargs:", kwargs)
    assert expert_cls in name_to_block
    num_handlers = num_handlers if num_handlers is not None else num_experts * 8
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    # initialize network
    network = None
    if not no_network:
        if isinstance(initial_peers, str):
            # peers may arrive as a python literal from the command line, e.g. "[('127.0.0.1', 1337)]"
            initial_peers = literal_eval(initial_peers)
        network = tesseract.TesseractNetwork(*initial_peers, port=network_port or tesseract.find_open_port(),
                                             start=True)
        if verbose:
            print("Parsed initial peers:", initial_peers)
            print(f"Running network node on port {network.port}")
    # initialize experts
    experts = {}
    for i in range(num_experts):
        expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
        # with no_optimizer, SGD with zero learning rate acts as a no-op optimizer
        opt = torch.optim.SGD(expert.parameters(), 0.0) if no_optimizer else torch.optim.Adam(expert.parameters())
        expert_uid = f'{expert_prefix}{i + expert_offset}'
        experts[expert_uid] = tesseract.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                      args_schema=(tesseract.BatchTensorProto(hidden_dim),),
                                                      outputs_schema=tesseract.BatchTensorProto(hidden_dim),
                                                      max_batch_size=max_batch_size,
                                                      )
    # actually start server
    server = tesseract.TesseractServer(
        network, experts, addr=host, port=port or tesseract.find_open_port(),
        conn_handler_processes=num_handlers, device=device)
    if start:
        server.run_in_background(await_ready=True)
        if verbose:
            print(f"Server started at {server.addr}:{server.port}")
            print(f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}")
    return server
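

# Usage sketch for make_dummy_server (assumes only the tesseract API as used above):
# calling it directly starts the server in this process; with the default
# expert_prefix and expert_offset, the expert uids are 'expert.0', 'expert.1', ...
#
#   server = make_dummy_server(num_experts=2, hidden_dim=64, no_network=True)
#   print(f"Serving at {server.addr}:{server.port}")
#   server.shutdown()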


@contextmanager
def background_server(*args, verbose=True, **kwargs):
    """ A context manager that runs the server in a background process, awaits readiness, yields (hostname, port) and shuts the server down on exit """
    recv_addr, send_addr = mp.Pipe(duplex=False)  # one-way pipe: the server process sends its address, we receive it
    trigger_shutdown = mp.Event()

    def server_runner():
        server = None
        try:
            server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
            send_addr.send((server.addr, server.port))
            trigger_shutdown.wait()
        finally:
            if verbose:
                print("Shutting down server...")
            trigger_shutdown.set()  # if server failed internally, set the shutdown trigger anyway
            if server is not None:  # if make_dummy_server raised, there is no server to shut down
                server.shutdown()
            if verbose:
                print("Server shut down successfully.")

    try:
        runner = mp.Process(target=server_runner)
        runner.start()
        yield recv_addr.recv()  # yield tuple (hostname, port)
    finally:
        trigger_shutdown.set()
        runner.join()
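

# A minimal usage sketch: launch a background server with two dummy ffn experts,
# print where it listens, then let the context manager shut it down. It relies
# only on the functions defined above; no_network=True skips TesseractNetwork
# initialization so the sketch runs standalone.
if __name__ == '__main__':
    with background_server(num_experts=2, expert_cls='ffn', hidden_dim=64,
                           no_network=True, verbose=True) as (addr, port):
        print(f"Background server is serving experts at {addr}:{port}")
    # exiting the with-block sets trigger_shutdown and joins the server process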