Bläddra i källkod

separate dht script

justheuristic 5 år sedan
förälder
incheckning
3ceb24d07d

+ 27 - 0
scripts/run_dht.py

@@ -0,0 +1,27 @@
+import argparse
+import resource
+import os
+import sys
+
+import torch
+import tesseract
+from tesseract.utils import find_open_port
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--port', type=int, default=None, required=False)
+    parser.add_argument('--initial_peers', type=str, default="[]", required=False)
+    parser.add_argument('--network_port', type=int, default=None, required=False)
+
+    args = parser.parse_args()
+    initial_peers = eval(args.initial_peers)
+    print("Parsed initial peers:", initial_peers)
+
+    network = tesseract.TesseractNetwork(*initial_peers, port=args.network_port or find_open_port(), start=False)
+    print(f"Running network node on port {network.port}")
+
+    try:
+        network.run()
+    finally:
+        network.shutdown()

+ 2 - 1
scripts/start_server.py → scripts/run_server.py

@@ -7,7 +7,8 @@ import torch
 import tesseract
 
 sys.path.append(os.path.dirname(__file__) + '/../tests')
-from test_utils import layers, find_open_port
+from test_utils import layers
+from tesseract import find_open_port
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

+ 9 - 0
tesseract/utils/connection.py

@@ -52,3 +52,12 @@ class Connection(AbstractContextManager):
 
     def __exit__(self, *exc_info):
         self.conn.close()
+
+
+def find_open_port():
+    try:
+        sock = socket()
+        sock.bind(('', 0))
+        return sock.getsockname()[1]
+    except:
+        raise ValueError("Could not find open port")

+ 2 - 1
tests/benchmark_throughput.py

@@ -6,7 +6,8 @@ import sys
 import time
 
 import torch
-from test_utils import layers, print_device_info, find_open_port
+from test_utils import layers, print_device_info
+from tesseract import find_open_port
 
 import tesseract
 

+ 0 - 11
tests/test_utils/__init__.py

@@ -1,5 +1,3 @@
-from socket import socket
-
 import torch
 
 
@@ -14,12 +12,3 @@ def print_device_info(device=None):
         print('Memory Usage:')
         print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
         print('Cached:   ', round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 'GB')
-
-
-def find_open_port():
-    try:
-        sock = socket()
-        sock.bind(('', 0))
-        return sock.getsockname()[1]
-    except:
-        raise ValueError("Could not find open port")