Jelajahi Sumber

add server script

justheuristic 5 tahun lalu
induk
melakukan
4e6f600686
2 mengubah file dengan 74 tambahan dan 1 penghapusan
  1. 73 0
      scripts/start_server.py
  2. 1 1
      tesseract/network/__init__.py

+ 73 - 0
scripts/start_server.py

@@ -0,0 +1,73 @@
+import argparse
+import multiprocessing as mp
+import random
+import resource
+import os
+import sys
+import time
+
+import torch
+sys.path.append(os.path.dirname(__file__) + '/../tests')
+from test_utils import layers, find_open_port
+import tesseract
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--expert_cls', type=str, default='ffn', required=False)
+    parser.add_argument('--num_experts', type=int, default=1, required=False)
+    parser.add_argument('--num_handlers', type=int, default=None, required=False)
+    parser.add_argument('--hidden_dim', type=int, default=1024, required=False)
+    parser.add_argument('--max_batch_size', type=int, default=16384, required=False)
+    parser.add_argument('--expert_prefix', type=str, default='expert', required=False)
+    parser.add_argument('--expert_offset', type=int, default=0, required=False)
+    parser.add_argument('--device', type=str, default=None, required=False)
+    parser.add_argument('--port', type=int, default=None, required=False)
+    parser.add_argument('--host', type=str, default='0.0.0.0', required=False)
+    parser.add_argument('--no_network', action='store_true')
+    parser.add_argument('--initial_peers', type=str, default="[]", required=False)
+    parser.add_argument('--network_port', type=int, default=None, required=False)
+    parser.add_argument('--increase_file_limit', action='store_true')
+
+    args = parser.parse_args()
+    if args.increase_file_limit:
+        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+        try:
+            print("Setting open file limit to soft={}, hard={}".format(max(soft, 2 ** 15), max(hard, 2 ** 15)))
+            resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 2 ** 15), max(hard, 2 ** 15)))
+        except:
+            print("Could not increase open file limit, currently at soft={}, hard={}".format(soft, hard))
+
+    assert args.expert_cls in layers.name_to_block
+    num_handlers = args.num_handlers or args.num_experts * 8
+    port = args.port or find_open_port()
+    network_port = args.network_port or find_open_port()
+
+    device = args.device or ('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # initialize network
+    network = None
+    if not args.no_network:
+        initial_peers = eval(args.initial_peers)
+        print("Parsed initial peers:", initial_peers)
+
+        network = tesseract.TesseractNetwork(*initial_peers, port=network_port, start=True)
+        print("Running network node on port {network_port}")
+
+    # initialize experts
+    experts = {}
+    for i in range(args.num_experts):
+        expert = torch.jit.script(layers.name_to_block[args.expert_cls](args.hidden_dim))
+        experts[f'expert{i}'] = tesseract.ExpertBackend(name=f'{args.expert_prefix}.{i + args.expert_offset}',
+                                                        expert=expert, opt=torch.optim.Adam(expert.parameters()),
+                                                        args_schema=(tesseract.BatchTensorProto(args.hidden_dim),),
+                                                        outputs_schema=tesseract.BatchTensorProto(args.hidden_dim),
+                                                        max_batch_size=args.max_batch_size,
+                                                        )
+    # start server
+    server = tesseract.TesseractServer(network, experts, port=port, conn_handler_processes=num_handlers, device=device)
+    print("Active expert uids:", sorted(experts))
+    try:
+        server.run()
+    finally:
+        server.shutdown()

+ 1 - 1
tesseract/network/__init__.py

@@ -69,7 +69,7 @@ class TesseractNetwork(mp.Process):
         :param wait_timeout: if wait_timeout > 0, waits for the procedure to finish
         """
         done_event = mp.Event() if wait_timeout else None
-        self.pipe.send(('_declare_experts', [], dict(uids=uids, addr=addr, port=port, done_event=done_event)))
+        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, done_event=done_event)))
         if done_event is not None:
             done_event.wait(wait_timeout)