@@ -4,14 +4,14 @@ import multiprocessing as mp
 import argparse
 
 import torch
-import tesseract
+import hivemind
 from .layers import name_to_block
 
 
 def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024, num_handlers=None,
                       expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None, no_optimizer=False,
                       no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True, start=False,
-                      UID_DELIMETER=tesseract.DHTNode.UID_DELIMETER, **kwargs) -> tesseract.TesseractServer:
+                      UID_DELIMETER=hivemind.DHTNode.UID_DELIMETER, **kwargs) -> hivemind.Server:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     if verbose and len(kwargs) != 0:
         print("Ignored kwargs:", kwargs)
@@ -24,8 +24,8 @@ def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
     if not no_dht:
         if not len(initial_peers):
             print("No initial peers provided. Starting additional dht as an initial peer.")
-            dht_root = tesseract.DHTNode(
-                *initial_peers, port=root_port or tesseract.find_open_port(), start=True)
+            dht_root = hivemind.DHTNode(
+                *initial_peers, port=root_port or hivemind.find_open_port(), start=True)
             print(f"Initializing DHT with port {dht_root.port}")
             initial_peers = (('localhost', dht_root.port), )
         else:
@@ -33,8 +33,8 @@ def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
             if root_port is not None:
                 print(f"Warning: root_port={root_port} will not be used since we already have peers.")
 
-        dht = tesseract.DHTNode(
-            *initial_peers, port=dht_port or tesseract.find_open_port(), start=True)
+        dht = hivemind.DHTNode(
+            *initial_peers, port=dht_port or hivemind.find_open_port(), start=True)
         if verbose:
             print(f"Running dht node on port {dht.port}")
 
@@ -44,14 +44,14 @@ def make_dummy_server(host='0.0.0.0', port=None, num_experts=1, expert_cls='ffn'
         expert = torch.jit.script(name_to_block[expert_cls](hidden_dim))
         opt = torch.optim.SGD(expert.parameters(), 0.0) if no_optimizer else torch.optim.Adam(expert.parameters())
         expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
-        experts[expert_uid] = tesseract.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
-                                                      args_schema=(tesseract.BatchTensorProto(hidden_dim),),
-                                                      outputs_schema=tesseract.BatchTensorProto(hidden_dim),
-                                                      max_batch_size=max_batch_size,
-                                                      )
+        experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
+                                                     args_schema=(hivemind.BatchTensorProto(hidden_dim),),
+                                                     outputs_schema=hivemind.BatchTensorProto(hidden_dim),
+                                                     max_batch_size=max_batch_size,
+                                                     )
     # actually start server
-    server = tesseract.TesseractServer(
-        dht, experts, addr=host, port=port or tesseract.find_open_port(),
+    server = hivemind.Server(
+        dht, experts, addr=host, port=port or hivemind.find_open_port(),
         conn_handler_processes=num_handlers, device=device)
 
     if start:
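
A minimal usage sketch of the renamed API follows; it is not part of the patch. It assumes make_dummy_server keeps the signature shown above and that the returned hivemind.Server exposes the .ready event and shutdown() implied by the docstring; the import path is hypothetical.

# Hypothetical usage sketch of the renamed hivemind API (not part of the diff).
from run_server import make_dummy_server  # import path is an assumption

# With these arguments the helper spawns a root DHT node, a DHT peer, and one
# scripted feed-forward expert; start=True launches the server immediately.
server = make_dummy_server(num_experts=1, expert_cls='ffn', hidden_dim=1024,
                           device='cpu', start=True)
try:
    server.ready.wait()  # assumed: the docstring says callers await .ready
finally:
    server.shutdown()    # assumed counterpart to the docstring's "shutdowns on exit"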