import ast
from contextlib import contextmanager

import torch

import tesseract
from .layers import name_to_block


@contextmanager
def background_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
                      expert_prefix='expert.', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
                      no_network=False, initial_peers=(), network_port=None, verbose=False, **kwargs
                      ) -> tesseract.TesseractServer:
    """ A context manager that creates a server in a background thread, awaits .ready on entry and shuts it down on exit """
    if verbose and len(kwargs) != 0:
        print("Ignored kwargs:", kwargs)
    assert expert_cls in name_to_block, f"Unknown expert class: {expert_cls}"
    num_handlers = num_handlers if num_handlers is not None else num_experts * 8
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    # initialize network
    network = None
    if not no_network:
        if isinstance(initial_peers, str):
            # initial_peers may be passed as a string (e.g. from a command line); parse it into (host, port) pairs
            initial_peers = ast.literal_eval(initial_peers)
        network = tesseract.TesseractNetwork(*initial_peers, port=network_port or tesseract.find_open_port(), start=True)
        if verbose:
            print("Parsed initial peers:", initial_peers)
            print(f"Running network node on port {network.port}")
    # initialize experts
    experts = {}
    for i in range(num_experts):
        expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
        opt = torch.optim.SGD(expert.parameters(), 0.0) if no_optimizer else torch.optim.Adam(expert.parameters())
        expert_uid = f'{expert_prefix}{i + expert_offset}'  # expert_prefix already ends with the separator, e.g. 'expert.0'
        experts[expert_uid] = tesseract.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                      args_schema=(tesseract.BatchTensorProto(hidden_dim),),
                                                      outputs_schema=tesseract.BatchTensorProto(hidden_dim),
                                                      max_batch_size=max_batch_size,
                                                      )
    # start server
    server = tesseract.TesseractServer(
        network, experts, addr=host, port=port or tesseract.find_open_port(),
        conn_handler_processes=num_handlers, device=device)
    try:
        server.run_in_background(await_ready=True)
        if verbose:
            print(f"Running server at {server.addr}:{server.port}")
            print(f"Active experts of type {expert_cls}: {list(experts.keys())}")
        yield server
    finally:
        if verbose:
            print("Shutting down server...")
        server.shutdown()
        if verbose:
            print("Server shut down successfully.")