|
@@ -14,6 +14,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
|
|
|
|
from src import declare_active_modules
|
|
|
from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
|
|
|
+from src.data_structures import UID_DELIMITER, CHAIN_DELIMITER
|
|
|
from src.server.backend import TransformerBackend
|
|
|
from src.server.cache import MemoryCache
|
|
|
from src.server.handler import TransformerConnectionHandler
|
|
@@ -84,7 +85,7 @@ class Server(threading.Thread):
|
|
|
@classmethod
|
|
|
def create(
|
|
|
cls,
|
|
|
- prefix: str,
|
|
|
+ prefix: Optional[str],
|
|
|
converted_model_name_or_path: str,
|
|
|
num_blocks: Optional[int] = None,
|
|
|
block_indices: Optional[str] = None,
|
|
@@ -108,6 +109,12 @@ class Server(threading.Thread):
|
|
|
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
|
|
|
if custom_module_path is not None:
|
|
|
add_custom_models_from_file(custom_module_path)
|
|
|
+ if prefix is None:
|
|
|
+ prefix = converted_model_name_or_path
|
|
|
+ assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix,\
|
|
|
+ f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); " \
|
|
|
+ f"Please specify --prefix manually when starting a server"
|
|
|
+ logger.info(f"Automatic dht prefix: {prefix}")
|
|
|
assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
|
|
|
dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
|
|
|
visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
|