@@ -13,7 +13,7 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import declare_active_modules, BloomConfig
+from src import BloomConfig, declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from src.dht_utils import get_remote_module_infos
@@ -98,7 +98,7 @@ class Server(threading.Thread):
         cls,
         prefix: Optional[str],
         converted_model_name_or_path: str,
-        throughput: Union[float, Literal['auto', 'eval']],
+        throughput: Union[float, Literal["auto", "eval"]],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
@@ -140,17 +140,15 @@ class Server(threading.Thread):
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         memory_cache = MemoryCache(device, cache_size_bytes)
 
-        assert isinstance(throughput, float) or throughput in ['auto', 'eval']
-        if throughput in ['auto', 'eval']:
-            throughput = get_host_throughput(device, force_eval=(throughput == 'eval'))
+        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
+        if throughput in ["auto", "eval"]:
+            throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
 
         if isinstance(torch_dtype, str):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
-        block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path, use_auth_token=use_auth_token
-        )
+        block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
 
         if block_indices is not None:
             try:
@@ -288,7 +286,7 @@ class ModuleAnnouncerThread(threading.Thread):
         throughput: float,
         update_period: float = 30,
         expiration: float,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.module_backends = module_backends