import ast
import multiprocessing as mp
from contextlib import contextmanager

import torch

import tesseract
from .layers import name_to_block


def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
                      expert_prefix='expert.', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
                      no_network=False, initial_peers=(), network_port=None, verbose=False, start=True, **kwargs
                      ) -> tesseract.TesseractServer:
- """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
    if verbose and len(kwargs) != 0:
        print("Ignored kwargs:", kwargs)
    assert expert_cls in name_to_block
    num_handlers = num_handlers if num_handlers is not None else num_experts * 8
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    # initialize network
    network = None
    if not no_network:
        if isinstance(initial_peers, str):
            # peers may arrive as a string literal (e.g. from the command line); parse it safely
            initial_peers = ast.literal_eval(initial_peers)
        network = tesseract.TesseractNetwork(*initial_peers, port=network_port or tesseract.find_open_port(),
                                             start=True)
        if verbose:
            print("Parsed initial peers:", initial_peers)
            print(f"Running network node on port {network.port}")
    # initialize experts
    experts = {}
    for i in range(num_experts):
        expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
        # with no_optimizer, SGD with lr=0.0 keeps the expert parameters frozen
        opt = torch.optim.SGD(expert.parameters(), 0.0) if no_optimizer else torch.optim.Adam(expert.parameters())
        expert_uid = f'{expert_prefix}{i + expert_offset}'
        experts[expert_uid] = tesseract.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                      args_schema=(tesseract.BatchTensorProto(hidden_dim),),
                                                      outputs_schema=tesseract.BatchTensorProto(hidden_dim),
                                                      max_batch_size=max_batch_size,
                                                      )
    # actually start server
    server = tesseract.TesseractServer(
        network, experts, addr=host, port=port or tesseract.find_open_port(),
        conn_handler_processes=num_handlers, device=device)

    if start:
        server.run_in_background(await_ready=True)
        if verbose:
            print(f"Server started at {server.addr}:{server.port}")
            print(f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}")
    return server
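
# Usage sketch (illustrative only, not part of the original module; assumes the
# 'ffn' block from .layers and lets server/network ports be chosen automatically):
#
#   server = make_dummy_server(num_experts=2, verbose=True)  # start=True by default
#   print(f"serving at {server.addr}:{server.port}")
#   server.shutdown()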


@contextmanager
def background_server(*args, verbose=True, **kwargs):
    """ Runs a dummy server in a background process, yields it on entry and shuts it down on exit. """
    recv_server, send_server = mp.Pipe(duplex=False)

    def server_runner():
        server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
        send_server.send(server)
        server.join()

    runner = mp.Process(target=server_runner)
    runner.start()
    server = recv_server.recv()  # blocks until the child process has built and started the server
    try:
        yield server
    finally:
        if verbose:
            print("Shutting down server...")
        server.shutdown()
        runner.terminate()
        runner.join()
        if verbose:
            print("Server shut down successfully.")