
Use bnb==0.40.0.post4 to fix bias bug, use bfloat16 by default

Aleksandr Borzunov committed 2 years ago
Parent commit: b0d55ee655

setup.cfg  +1 -1

@@ -32,7 +32,7 @@ packages = find:
 python_requires = >=3.7
 install_requires =
     torch>=1.12
-    bitsandbytes==0.40.0.post3
+    bitsandbytes==0.40.0.post4
     accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
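
A quick sanity check after reinstalling the package (hypothetical snippet; bitsandbytes exposes a __version__ string):

    # Verify that the pinned bitsandbytes release with the bias fix is installed
    import bitsandbytes
    assert bitsandbytes.__version__ == "0.40.0.post4", bitsandbytes.__version__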

src/petals/client/from_pretrained.py  +2 -3

@@ -29,9 +29,8 @@ class FromPretrainedMixin:
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
         if torch_dtype is None:
-            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
-            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
-            torch_dtype = "auto"
+            # torch_dtype=None gives torch.float32 in transformers>=4.26.0
+            torch_dtype = torch.bfloat16
 
         with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
             return super().from_pretrained(
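
With this change, a client that omits torch_dtype now loads weights in bfloat16 instead of relying on "auto" resolution. A minimal sketch, assuming the petals AutoDistributedModelForCausalLM entry point and an illustrative model name served by a swarm:

    import torch
    from petals import AutoDistributedModelForCausalLM

    # torch_dtype is left unspecified, so the mixin above falls back to torch.bfloat16
    model = AutoDistributedModelForCausalLM.from_pretrained("bigscience/bloom")
    assert model.dtype == torch.bfloat16  # local (client-side) weights should be bfloat16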

src/petals/server/block_utils.py  +0 -2

@@ -11,8 +11,6 @@ def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
     if dtype not in ("auto", None):
         return dtype
-    if config.torch_dtype not in ("auto", None):
-        return config.torch_dtype
     return torch.bfloat16
 
 

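After this removal, resolve_block_dtype no longer consults config.torch_dtype: servers resolve "auto" and None straight to bfloat16, and any explicit dtype passes through unchanged. A sketch of the resulting behavior, using a bare PretrainedConfig as a stand-in:

    import torch
    from transformers import PretrainedConfig
    from petals.server.block_utils import resolve_block_dtype

    config = PretrainedConfig()  # any config works; config.torch_dtype is ignored now
    assert resolve_block_dtype(config, "auto") == torch.bfloat16
    assert resolve_block_dtype(config, torch.float16) == torch.float16  # returned intact
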
src/petals/server/server.py  +6 -0

@@ -173,6 +173,12 @@ class Server:
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
+        if self.block_config.model_type == "llama" and torch_dtype == torch.bfloat16 and quant_type != QuantType.NF4:
+            logger.warning(
+                "LLaMA is loaded in bfloat16 for compatibility with --quant_type nf4 servers (default). "
+                "If you use a private swarm without such servers, use --torch_dtype float16 to force the original float16 dtype"
+            )
+
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
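
For private swarms without NF4 servers, the warning's suggestion amounts to launching the server with the original dtype, roughly like this (illustrative model name; --torch_dtype is the flag referenced in the warning above):

    python -m petals.cli.run_server enoch/llama-65b-hf --torch_dtype float16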