start_server.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
import argparse
import ast
import multiprocessing as mp
import os
import random
import resource
import sys
import time

import torch

sys.path.append(os.path.dirname(__file__) + '/../tests')
from test_utils import layers, find_open_port
import tesseract
  12. if __name__ == "__main__":
  13. parser = argparse.ArgumentParser()
  14. parser.add_argument('--expert_cls', type=str, default='ffn', required=False)
  15. parser.add_argument('--num_experts', type=int, default=1, required=False)
  16. parser.add_argument('--num_handlers', type=int, default=None, required=False)
  17. parser.add_argument('--hidden_dim', type=int, default=1024, required=False)
  18. parser.add_argument('--max_batch_size', type=int, default=16384, required=False)
  19. parser.add_argument('--expert_prefix', type=str, default='expert', required=False)
  20. parser.add_argument('--expert_offset', type=int, default=0, required=False)
  21. parser.add_argument('--device', type=str, default=None, required=False)
  22. parser.add_argument('--port', type=int, default=None, required=False)
  23. parser.add_argument('--host', type=str, default='0.0.0.0', required=False)
  24. parser.add_argument('--no_network', action='store_true')
  25. parser.add_argument('--initial_peers', type=str, default="[]", required=False)
  26. parser.add_argument('--network_port', type=int, default=None, required=False)
  27. parser.add_argument('--increase_file_limit', action='store_true')
  28. args = parser.parse_args()
  29. if args.increase_file_limit:
  30. soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
  31. try:
  32. print("Setting open file limit to soft={}, hard={}".format(max(soft, 2 ** 15), max(hard, 2 ** 15)))
  33. resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 2 ** 15), max(hard, 2 ** 15)))
  34. except:
  35. print("Could not increase open file limit, currently at soft={}, hard={}".format(soft, hard))
  36. assert args.expert_cls in layers.name_to_block
  37. args.num_handlers = args.num_handlers or args.num_experts * 8
  38. device = args.device or ('cuda' if torch.cuda.is_available() else 'cpu')
  39. # initialize network
  40. network = None
  41. if not args.no_network:
  42. initial_peers = eval(args.initial_peers)
  43. print("Parsed initial peers:", initial_peers)
  44. network = tesseract.TesseractNetwork(*initial_peers, port=args.network_port or find_open_port(), start=True)
  45. print(f"Running network node on port {network.port}")
  46. # initialize experts
  47. experts = {}
  48. for i in range(args.num_experts):
  49. expert = torch.jit.script(layers.name_to_block[args.expert_cls](args.hidden_dim))
  50. expert_uid = f'{args.expert_prefix}.{i + args.expert_offset}'
  51. experts[expert_uid] = tesseract.ExpertBackend(name=expert_uid,
  52. expert=expert, opt=torch.optim.Adam(expert.parameters()),
  53. args_schema=(tesseract.BatchTensorProto(args.hidden_dim),),
  54. outputs_schema=tesseract.BatchTensorProto(args.hidden_dim),
  55. max_batch_size=args.max_batch_size,
  56. )
  57. # start server
  58. server = tesseract.TesseractServer(
  59. network, experts, addr=args.host, port=args.port or find_open_port(),
  60. conn_handler_processes=args.num_handlers, device=device)
  61. print(f"Running server at {server.addr}:{server.port}")
  62. print(f"Active expert uids: {experts}")
  63. try:
  64. server.run()
  65. finally:
  66. server.shutdown()