@@ -12,12 +12,12 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import BloomForCausalLM, DistributedBloomConfig
-from src.bloom.block import BloomBlock
+from src.bloom.from_pretrained import load_pretrained_block, DistributedBloomConfig, DTYPE_MAP
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
@@ -83,12 +83,13 @@ class Server(threading.Thread):
     def create(
         cls,
         prefix: str,
-        block_config: str,
+        converted_model_name_or_path: str,
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
+        torch_dtype: str = 'auto',
         cache_size_bytes: Optional[int] = None,
         device: Union[str, torch.device] = None,
         initial_peers: Sequence[str] = (),
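For orientation, here is a sketch of how the updated Server.create signature might be called. It is not part of this diff: the prefix, model name, and block range are hypothetical placeholders, and arguments not shown in this hunk are omitted.

# Hypothetical usage sketch; all values below are placeholders, not taken from this diff.
server = Server.create(
    prefix="bloom",                                    # hypothetical module-uid prefix
    converted_model_name_or_path="bigscience/bloom",   # hypothetical converted checkpoint
    block_indices="0:8",                               # parsed below as a start:end range
    torch_dtype="auto",                                # resolved via DTYPE_MAP (see next hunk)
    device="cuda",
)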
@@ -112,6 +113,10 @@ class Server(threading.Thread):
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         memory_cache = MemoryCache(device, cache_size_bytes)
 
+        if isinstance(torch_dtype, str):
+            torch_dtype = DTYPE_MAP[torch_dtype]
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+
         if block_indices is not None:
             try:
                 start, end = block_indices.split(":")
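The string-to-dtype resolution above relies on DTYPE_MAP imported from src.bloom.from_pretrained. Its contents are not shown in this diff; the sketch below is an assumed shape, consistent with the default torch_dtype='auto' and the assert above, not a copy of the real definition.

# Assumed shape of DTYPE_MAP (the real definition lives in src.bloom.from_pretrained;
# the exact entries here are an assumption).
import torch

DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

# With a mapping like this, torch_dtype="auto" stays a sentinel handled downstream,
# while explicit names resolve to concrete torch dtypes before block loading.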
@@ -124,16 +129,13 @@ class Server(threading.Thread):
             assert num_blocks is not None
             block_indices = range(num_blocks)  # TODO replace with proper load balancing
 
-        ## TODO: the code below will load the entire model in RAM. Please replace with sliced model
-        block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
-        # model = BloomForCausalLM.from_pretrained(model, use_auth_token=True)
-        ## /TODO
+        block_config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=True)
 
         # initialize modules
         blocks = {}
         for block_index in block_indices:
             module_uid = f"{prefix}.{block_index}"
-            block = BloomBlock(block_config, layer_number=block_index)
+            block = load_pretrained_block(converted_model_name_or_path, block_index, block_config, torch_dtype)
             for param in block.parameters():
                 param.requires_grad = False
 
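Compared to the old path, which built an uninitialized BloomBlock and carried a TODO about loading the whole model into RAM, each served block is now materialized directly from the converted checkpoint. Below is a minimal standalone sketch of that loop, assuming the load_pretrained_block call order used above; the model name is a hypothetical placeholder.

# Standalone sketch of per-block loading; the model name is a hypothetical placeholder.
from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block

model_name = "bigscience/bloom"  # hypothetical converted checkpoint
config = DistributedBloomConfig.from_pretrained(model_name, use_auth_token=True)
torch_dtype = DTYPE_MAP["auto"]

blocks = {}
for block_index in range(2):  # e.g. serve blocks 0 and 1
    block = load_pretrained_block(model_name, block_index, config, torch_dtype)
    block.requires_grad_(False)  # freeze weights, same effect as the per-parameter loop above
    blocks[f"bloom.{block_index}"] = block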