@@ -70,9 +70,10 @@ class Server(threading.Thread):
     @classmethod
     def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
-               num_warmup_steps=None, num_training_steps=None, num_handlers=None, max_batch_size=4096, device=None,
-               no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
+               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
+               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
+               compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, *, start: bool,
+               **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -89,7 +90,8 @@ class Server(threading.Thread):
         :param optim_cls: uses this optimizer to train all experts
         :param scheduler: if not `none`, the name of the expert LR scheduler
         :param num_warmup_steps: the number of warmup steps for LR schedule
-        :param num_training_steps: the total number of steps for LR schedule
+        :param num_total_steps: the total number of steps for LR schedule
+        :param clip_grad_norm: maximum gradient norm used for clipping

         :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: a list of peers that will introduce this node to the dht,\
@@ -105,6 +107,7 @@ class Server(threading.Thread):
             for each BatchTensorProto in ExpertBackend for the respective experts.

         :param start: if True, starts server right away and returns when server is ready for requests
+        :param stats_report_interval: interval between two reports of batch processing performance statistics
         """
         if len(kwargs) != 0:
             logger.info("Ignored kwargs:", kwargs)
@@ -165,14 +168,15 @@ class Server(threading.Thread):
                                                          optimizer=optim_cls(expert.parameters()),
                                                          scheduler=scheduler,
                                                          num_warmup_steps=num_warmup_steps,
-                                                         num_training_steps=num_training_steps,
+                                                         num_total_steps=num_total_steps,
+                                                         clip_grad_norm=clip_grad_norm,
                                                          max_batch_size=max_batch_size)

         if checkpoint_dir is not None:
             load_experts(experts, checkpoint_dir)

         return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
-                   checkpoint_dir=checkpoint_dir, start=start)
+                   checkpoint_dir=checkpoint_dir, stats_report_interval=stats_report_interval, start=start)

     def run(self):
         """
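
For context, a minimal sketch of how the updated signature might be invoked after this change. Only the parameter names come from the diff above; the argument values (and the assumption that Server is exposed as hivemind.Server) are illustrative, not part of the patch.

    import hivemind

    # Hypothetical call using the renamed and newly added parameters.
    server = hivemind.Server.create(
        listen_on='0.0.0.0:*',
        num_experts=2,
        expert_cls='ffn',
        hidden_dim=1024,
        num_total_steps=10000,     # renamed from num_training_steps
        clip_grad_norm=1.0,        # new: maximum gradient norm used for clipping
        stats_report_interval=60,  # new: interval between performance-statistics reports
        start=True,                # returns once the server is ready for requests
    )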
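
The new clip_grad_norm value is forwarded into each ExpertBackend. A plausible reading, sketched with PyTorch's built-in clipping utility, is that the backend clips the global gradient norm before each optimizer step; this is an assumption about the backend's update loop, not code from this patch, and apply_updates is a hypothetical helper name.

    import torch

    def apply_updates(parameters, optimizer, clip_grad_norm=None):
        # Assumed behavior: clip the global gradient norm before stepping,
        # but only when a limit was configured.
        if clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(parameters, max_norm=clip_grad_norm)
        optimizer.step()
        optimizer.zero_grad()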