
Determine block dtype in a unified manner (#325)

* Extract backend_dtype, remove duplicate DTYPE_MAP

* Use bfloat16 as the default dtype, resolve dtype in load_pretrained_block
Max Ryabinin, 2 years ago
parent commit c839173e57

+ 3 - 2
src/petals/bloom/from_pretrained.py

@@ -21,7 +21,7 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.utils import get_file_from_repo
 
 from petals.bloom.block import WrappedBloomBlock
-from petals.server.block_utils import get_block_size
+from petals.server.block_utils import get_block_size, resolve_block_dtype
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 
 logger = get_logger(__name__)
@@ -41,6 +41,7 @@ def load_pretrained_block(
 ) -> WrappedBloomBlock:
     """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
     assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+    torch_dtype = resolve_block_dtype(config, torch_dtype)
 
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
@@ -66,7 +67,7 @@ def load_pretrained_block(
     for param_name, _ in block.named_parameters():
         assert param_name in state_dict, f"{param_name} not in state dict"
         param = state_dict[param_name]
-        if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
             param = param.to(torch_dtype)
         set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
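
For context, a minimal usage sketch of the updated entry point (the model name below is a placeholder, not part of this commit; it assumes a converted BLOOM repo is reachable): passing torch_dtype="auto" is now resolved to a concrete dtype before the parameters are cast.

from petals.bloom.from_pretrained import load_pretrained_block
from petals.client import DistributedBloomConfig

MODEL_NAME = "bigscience/bloom-petals"  # placeholder repo id for illustration only
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
block = load_pretrained_block(MODEL_NAME, 0, config, torch_dtype="auto")
# "auto" now falls back to config.torch_dtype, or torch.bfloat16 if the config leaves it unset
print(next(block.parameters()).dtype)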
 

+ 1 - 3
src/petals/cli/convert_model.py

@@ -10,13 +10,11 @@ from huggingface_hub import HfApi, Repository
 from tqdm.auto import tqdm
 from transformers.models.bloom.modeling_bloom import BloomModel
 
-from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
+from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH, DTYPE_MAP
 from petals.client import DistributedBloomConfig
 
 logger = get_logger(__name__)
 
-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
-
 
 def main():
     parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")

+ 6 - 7
src/petals/server/block_utils.py

@@ -7,14 +7,13 @@ from transformers import BloomConfig
 from petals.bloom.block import WrappedBloomBlock
 
 
-def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
+def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
-
-    if dtype == "auto" or dtype is None:
-        dtype = config.torch_dtype
-        if dtype == "auto" or dtype is None:
-            dtype = torch.float32
-    return dtype
+    if dtype not in ("auto", None):
+        return dtype
+    if config.torch_dtype not in ("auto", None):
+        return config.torch_dtype
+    return torch.bfloat16
 
 
 def get_block_size(
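
As a hedged illustration of the resolution order in the updated resolve_block_dtype (assuming a BloomConfig whose torch_dtype is left unset, which is the default): an explicit dtype passes through unchanged, while "auto" or None falls back to config.torch_dtype and finally to torch.bfloat16.

import torch
from transformers import BloomConfig
from petals.server.block_utils import resolve_block_dtype

config = BloomConfig()  # torch_dtype defaults to None here (assumption for this sketch)
assert resolve_block_dtype(config, torch.float16) == torch.float16  # explicit dtype is returned intact
assert resolve_block_dtype(config, "auto") == torch.bfloat16        # no config dtype, so bfloat16 default

config.torch_dtype = torch.float32
assert resolve_block_dtype(config, "auto") == torch.float32         # "auto" resolves from the config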

+ 6 - 6
src/petals/server/server.py

@@ -22,7 +22,7 @@ from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
-from petals.server.block_utils import get_block_size
+from petals.server.block_utils import get_block_size, resolve_block_dtype
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
@@ -151,7 +151,7 @@ class Server:
         if isinstance(torch_dtype, str):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        self.torch_dtype = torch_dtype
+        self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype)
 
         if tensor_parallel_devices is None:
             tensor_parallel_devices = (device,)
@@ -182,6 +182,7 @@ class Server:
         if attn_cache_size is None:
             # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
             attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
+
         self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
         logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
 
@@ -404,22 +405,21 @@ class ModuleContainer(threading.Thread):
                 )
                 block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
 
-                backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
                     config=block_config,
                     memory_cache=memory_cache,
-                    backend_dtype=backend_dtype,
+                    backend_dtype=torch_dtype,
                     args_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                         ),
                     ),
                     kwargs_schema={},
                     outputs_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                         ),
                     ),
                     min_batch_size=min_batch_size,

+ 17 - 0
tests/test_dtype.py

@@ -0,0 +1,17 @@
+import pytest
+import torch
+
+from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client import DistributedBloomConfig
+from petals.server.block_utils import resolve_block_dtype
+from test_utils import MODEL_NAME
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"])
+def test_backend_dtype(torch_dtype):
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    block = load_pretrained_block(MODEL_NAME, 0, config, torch_dtype=torch_dtype)
+    backend_dtype = resolve_block_dtype(config, torch_dtype)
+    other_backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
+    assert backend_dtype == other_backend_dtype