@@ -22,7 +22,7 @@ from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
-from petals.server.block_utils import get_block_size
+from petals.server.block_utils import get_block_size, resolve_block_dtype
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
@@ -151,7 +151,7 @@ class Server:
         if isinstance(torch_dtype, str):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        self.torch_dtype = torch_dtype
+        self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype)
 
         if tensor_parallel_devices is None:
             tensor_parallel_devices = (device,)
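
Note: resolve_block_dtype is assumed here to map the "auto" sentinel (and None) to a concrete torch.dtype before the value is stored on the server. A minimal sketch of that contract, assuming config-based fallback; this is not the actual petals.server.block_utils source:

    # Hypothetical sketch: resolve "auto"/None using the block config,
    # falling back to float32 when the config does not pin a dtype.
    import torch
    from transformers import PretrainedConfig

    def resolve_block_dtype(config: PretrainedConfig, dtype):
        if dtype not in ("auto", None):
            return dtype  # already a concrete torch.dtype
        if getattr(config, "torch_dtype", None) not in ("auto", None):
            return config.torch_dtype
        return torch.float32
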
@@ -182,6 +182,7 @@ class Server:
         if attn_cache_size is None:
             # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
             attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
+
         self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
         logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
 
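
The default above budgets 0.5 GiB of attention cache per served block for bigscience/bloom-petals (hidden_size = 14336) and scales linearly for other hidden sizes. A quick check of the arithmetic with illustrative numbers:

    # Illustrative values: 10 blocks of a model with BLOOM's hidden size.
    gib = 1024**3
    num_blocks, hidden_size = 10, 14336
    attn_cache_size = 0.5 * gib * num_blocks * hidden_size / 14336
    print(f"{attn_cache_size / gib:.2f} GiB")  # 5.00 GiB, i.e. 0.5 GiB per block
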
@@ -404,22 +405,21 @@ class ModuleContainer(threading.Thread):
                 )
                 block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
 
-                backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
                     config=block_config,
                     memory_cache=memory_cache,
-                    backend_dtype=backend_dtype,
+                    backend_dtype=torch_dtype,
                     args_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                         ),
                     ),
                     kwargs_schema={},
                     outputs_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                         ),
                     ),
                     min_batch_size=min_batch_size,
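
Because Server.__init__ now resolves "auto" up front, the dtype reaching ModuleContainer.create is already concrete, which is what makes the per-block backend_dtype fallback removable: the resolved torch_dtype can be passed directly to TransformerBackend and to both tensor schemas. An illustrative check, reusing the hypothetical sketch above with a small config as a stand-in:

    # Hypothetical usage: "auto" never survives past dtype resolution.
    import torch
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("bigscience/bloom-560m")  # stand-in model
    dtype = resolve_block_dtype(config, "auto")
    assert isinstance(dtype, torch.dtype)  # a concrete dtype, not the "auto" string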