|
@@ -17,7 +17,8 @@ from hivemind.server.checkpoints import CheckpointSaver, load_experts, is_direct
|
|
|
from hivemind.server.connection_handler import ConnectionHandler
|
|
|
from hivemind.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
|
|
|
from hivemind.server.expert_backend import ExpertBackend
|
|
|
-from hivemind.server.layers import name_to_block, name_to_input, schedule_name_to_scheduler
|
|
|
+from hivemind.server.layers import name_to_block, name_to_input
|
|
|
+from hivemind.server.layers import add_custom_models_from_file, schedule_name_to_scheduler
|
|
|
from hivemind.server.runtime import Runtime
|
|
|
from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
|
|
|
from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger
|
|
@@ -72,8 +73,8 @@ class Server(threading.Thread):
|
|
|
expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
|
|
|
num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
|
|
|
device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
|
|
|
- compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, *, start: bool,
|
|
|
- **kwargs) -> Server:
|
|
|
+ compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None,
|
|
|
+ *, start: bool, **kwargs) -> Server:
|
|
|
"""
|
|
|
Instantiate a server with several identical experts. See argparse comments below for details
|
|
|
:param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
|
|
@@ -109,6 +110,9 @@ class Server(threading.Thread):
|
|
|
:param start: if True, starts server right away and returns when server is ready for requests
|
|
|
:param stats_report_interval: interval between two reports of batch processing performance statistics
|
|
|
"""
|
|
|
+ if custom_module_path is not None:
|
|
|
+ add_custom_models_from_file(custom_module_path)
|
|
|
+
|
|
|
if len(kwargs) != 0:
|
|
|
logger.info("Ignored kwargs:", kwargs)
|
|
|
assert expert_cls in name_to_block
|