@@ -3,17 +3,17 @@ import torch
 import tesseract
 from .layers import name_to_block
 from contextlib import contextmanager
+import multiprocessing as mp
 
 
-@contextmanager
-def background_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
+def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
                       expert_prefix='expert.', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
-                      no_network=False, initial_peers=(), network_port=None, verbose=False, **kwargs
+                      no_network=False, initial_peers=(), network_port=None, verbose=False, start=True, **kwargs
                       ) -> tesseract.TesseractServer:
""" A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
|
|
|
if verbose and len(kwargs) == 0:
|
|
|
print("Ignored kwargs:", kwargs)
|
|
|
- expert_cls in name_to_block
|
|
|
+ assert expert_cls in name_to_block
|
|
|
num_handlers = num_handlers if num_handlers is not None else num_experts * 8
|
|
|
device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
@@ -21,7 +21,8 @@ def background_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
     network = None
     if not no_network:
         initial_peers = eval(initial_peers)
-        network = tesseract.TesseractNetwork(*initial_peers, port=network_port or tesseract.find_open_port(), start=True)
+        network = tesseract.TesseractNetwork(*initial_peers, port=network_port or tesseract.find_open_port(),
+                                             start=True)
         if verbose:
             print("Parsed initial peers:", initial_peers)
             print(f"Running network node on port {network.port}")
@@ -37,19 +38,34 @@ def background_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
         outputs_schema=tesseract.BatchTensorProto(hidden_dim),
         max_batch_size=max_batch_size,
     )
-    # start server
+    # actually start server
     server = tesseract.TesseractServer(
         network, experts, addr=host, port=port or tesseract.find_open_port(),
         conn_handler_processes=num_handlers, device=device)
-    try:
+
+    if start:
         server.run_in_background(await_ready=True)
         if verbose:
             print(f"Running server at {server.addr}:{server.port}")
             print(f"Active experts of type {expert_cls}: {list(experts.keys())}")
+    return server
+
+
+@contextmanager
+def background_server(*args, verbose=True, **kwargs):
+    """ A context manager that runs the server in a background process, yields it once ready, and shuts it down on exit """
+    try:
+        server = make_dummy_server(*args, verbose=verbose, start=False, **kwargs)
+
+        runner = mp.Process(target=lambda: (server.start(), server.join()))  # lambda target requires a fork-based start method
+        runner.start()
+        server.ready.wait()
         yield server
+        runner.join()
     finally:
         if verbose:
             print("Shutting down server...")
         server.shutdown()
+        runner.terminate()
         if verbose:
             print("Server shut down successfully.")
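
Usage sketch, not part of the patch: a minimal example of the two entry points this change introduces, assuming the functions above are importable as run_server (the module path is an assumption; every call below mirrors the diff).

    # a minimal sketch, assuming the module path run_server is correct
    from run_server import background_server, make_dummy_server

    # Context-manager path: the server starts in a background process,
    # is yielded once server.ready is set, and is shut down on exit.
    with background_server(num_experts=2, expert_cls='ffn', no_network=True) as server:
        print(f"Serving experts at {server.addr}:{server.port}")

    # Manual path: with start=True, make_dummy_server launches the server
    # in the current process via run_in_background and returns it once ready.
    server = make_dummy_server(num_experts=2, expert_cls='ffn', no_network=True, start=True)
    try:
        print(f"Serving experts at {server.addr}:{server.port}")
    finally:
        server.shutdown()

The start flag is what lets background_server reuse make_dummy_server: it builds the server with start=False in the parent, then starts it inside a child process, so cleanup can both call server.shutdown() and terminate the runner process.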