Aleksandr Borzunov 2 роки тому
батько
коміт
f53c581690
1 змінених файлів з 7 додано та 3 видалено
  1. 7 3
      src/petals/server/server.py

+ 7 - 3
src/petals/server/server.py

@@ -173,10 +173,14 @@ class Server:
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
-        if self.block_config.model_type == "llama" and torch_dtype == torch.bfloat16 and quant_type != QuantType.NF4:
+        if (
+            self.block_config.torch_dtype == torch.float16  # If weights are in float16
+            and torch_dtype == torch.bfloat16  # but we load them in bfloat16
+            and quant_type != QuantType.NF4
+        ):
             logger.warning(
-                "LLaMA is loaded in bfloat16 for compatibility with --quant_type nf4 servers (default). "
-                "If you use a private swarm without such servers, use --torch_dtype float16 to force the original float16 dtype"
+                "LLaMA is loaded in bfloat16 for compatibility with NF4 servers holding Guanaco adapters. "
+                "If you want to run it in float16, use --torch_dtype float16"
             )
 
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens