|
@@ -20,6 +20,7 @@ from hivemind.server.layers import name_to_block, name_to_input
|
|
|
from hivemind.server.runtime import Runtime
|
|
|
from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
|
|
|
from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
|
|
|
+from hivemind.proto.runtime_pb2 import CompressionType
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
@@ -69,7 +70,7 @@ class Server(threading.Thread):
|
|
|
def create(listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
|
|
|
expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, num_handlers=None, max_batch_size=4096,
|
|
|
device=None, no_dht=False, initial_peers=(), dht_port=None, verbose=True,
|
|
|
- *, start: bool, **kwargs) -> Server:
|
|
|
+ compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
|
|
|
"""
|
|
|
Instantiate a server with several identical experts. See argparse comments below for details
|
|
|
:param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
|
|
@@ -118,9 +119,9 @@ class Server(threading.Thread):
|
|
|
|
|
|
sample_input = name_to_input[expert_cls](4, hidden_dim)
|
|
|
if isinstance(sample_input, tuple):
|
|
|
- args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg) for arg in sample_input)
|
|
|
+ args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
|
|
|
else:
|
|
|
- args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input),)
|
|
|
+ args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input, compression),)
|
|
|
|
|
|
# initialize experts
|
|
|
|
|
@@ -129,7 +130,8 @@ class Server(threading.Thread):
|
|
|
expert = name_to_block[expert_cls](hidden_dim)
|
|
|
experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert,
|
|
|
args_schema=args_schema,
|
|
|
- outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim),
|
|
|
+ outputs_schema=hivemind.BatchTensorDescriptor(
|
|
|
+ hidden_dim, compression=compression),
|
|
|
opt=optim_cls(expert.parameters()),
|
|
|
max_batch_size=max_batch_size,
|
|
|
)
|