@@ -26,6 +26,7 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
+from hivemind.optim import CollaborativeOptimizer
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
 
@@ -107,6 +108,9 @@ class Server(threading.Thread):
         num_handlers=None,
         min_batch_size=1,
         max_batch_size=4096,
+        use_averaging: bool = False,
+        averaging_target_batch_size: Optional[int] = None,
+        averaging_target_group_size: Optional[int] = None,
         device=None,
         no_dht=False,
         initial_peers=(),
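For context, a minimal launch sketch using the new keyword arguments. This is not part of the patch: the `Server.create` entry point and every value below are assumptions for illustration only.

```python
# Hedged sketch, not from this patch: assumes the new kwargs are exposed
# through Server.create and that a DHT swarm is reachable.
server = Server.create(
    num_experts=4,
    expert_pattern="myprefix.[0:32].[0:256]",  # uid pattern, documented below
    expert_cls="ffn",
    hidden_dim=512,
    use_averaging=True,
    averaging_target_batch_size=4096,  # examples accumulated swarm-wide per step
    averaging_target_group_size=8,     # peers per averaging group
    initial_peers=["127.0.0.1:1337"],  # hypothetical bootstrap peer
    start=True,
)
```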
@@ -122,13 +126,17 @@ class Server(threading.Thread):
     :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
     :param num_experts: run this many identical experts
     :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
-     means "sample random experts between myprefix.0.0 and myprefix.255.255;
+        means "sample random experts between myprefix.0.0 and myprefix.255.255;
     :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
     :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
     :param hidden_dim: main dimension for expert_cls
     :param num_handlers: server will use this many parallel processes to handle incoming requests
     :param min_batch_size: total num examples in the same batch will be greater than this value
     :param max_batch_size: total num examples in the same batch will not exceed this value
+    :param use_averaging: whether to use decentralized parameter and gradient averaging by wrapping the optimizer
+        with CollaborativeOptimizer
+    :param averaging_target_batch_size: number of examples to accumulate across all peers before averaging
+    :param averaging_target_group_size: target group size for decentralized averaging
     :param device: all experts will use this device in torch notation; default: cuda if available else cpu
 
     :param optim_cls: uses this optimizer to train all experts
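To make the `expert_pattern` grammar concrete, here is a toy stand-in for uid sampling (not hivemind's `generate_uids_from_pattern`): each `[low:high]` block is replaced with a random integer from that half-open range.

```python
import random
import re

def sample_uid(pattern: str) -> str:
    # Replace each [low:high] block with a random integer from [low, high),
    # e.g. "myprefix.[0:32].[0:256]" -> "myprefix.17.203".
    return re.sub(
        r"\[(\d+):(\d+)\]",
        lambda m: str(random.randrange(int(m.group(1)), int(m.group(2)))),
        pattern,
    )

print(sample_uid("myprefix.[0:32].[0:256]"))
```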
@@ -184,7 +192,11 @@ class Server(threading.Thread):
             uids_to_generate = num_experts - len(expert_uids)
             if uids_to_generate > 0:
                 logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
-                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))
+                expert_uids.extend(
+                    generate_uids_from_pattern(
+                        uids_to_generate, expert_pattern, dht, remove_duplicates=not use_averaging
+                    )
+                )
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
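The new `remove_duplicates=not use_averaging` argument flips uid filtering: without averaging, a uid already announced by another peer is dropped, while with averaging a duplicate uid is exactly what lets several peers host and average the same expert. A toy model of that flag (not hivemind's implementation):

```python
def filter_candidate_uids(candidates, taken_in_dht, remove_duplicates):
    # With remove_duplicates=False (averaging enabled), keep uids that other
    # peers already announced, so replicas of the same expert get averaged.
    if not remove_duplicates:
        return list(candidates)
    return [uid for uid in candidates if uid not in taken_in_dht]

assert filter_candidate_uids(["e.0", "e.1"], {"e.0"}, True) == ["e.1"]
assert filter_candidate_uids(["e.0", "e.1"], {"e.0"}, False) == ["e.0", "e.1"]
```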
@@ -203,11 +215,27 @@ class Server(threading.Thread):
         experts = {}
         for expert_uid in expert_uids:
             expert = name_to_block[expert_cls](hidden_dim)
+
+            optim = optim_cls(expert.parameters())
+            if use_averaging:
+                assert averaging_target_batch_size is not None
+                assert averaging_target_group_size is not None
+                optim = CollaborativeOptimizer(
+                    optim,
+                    dht=dht,
+                    prefix=expert_uid.replace(".", ""),
+                    compression_type=compression,
+                    target_batch_size=averaging_target_batch_size,
+                    target_group_size=averaging_target_group_size,
+                    reuse_grad_buffers=True,
+                    start=True,
+                )
+
             experts[expert_uid] = hivemind.ExpertBackend(
                 name=expert_uid,
                 expert=expert,
                 args_schema=args_schema,
-                optimizer=optim_cls(expert.parameters()),
+                optimizer=optim,
                 scheduler=scheduler,
                 num_warmup_steps=num_warmup_steps,
                 num_total_steps=num_total_steps,
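One detail worth noting in the hunk above: each wrapped optimizer gets its own DHT key prefix, derived by stripping the dots from the expert uid, so only replicas of the same expert join the same averaging group. A trivial snippet of the derivation (the uid is hypothetical):

```python
# Grounded in the patch above: the averaging prefix is the uid minus its dots.
expert_uid = "myprefix.3.14"           # hypothetical expert uid
prefix = expert_uid.replace(".", "")
assert prefix == "myprefix314"
```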