@@ -173,6 +173,12 @@ class Server:
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
+        if self.block_config.model_type == "llama" and torch_dtype == torch.bfloat16 and quant_type != QuantType.NF4:
+            logger.warning(
+                "LLaMA is loaded in bfloat16 for compatibility with --quant_type nf4 servers (default). "
+                "If you use a private swarm without such servers, use --torch_dtype float16 to force the original float16 dtype"
+            )
+
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
 
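For context, the last two lines of the hunk size the attention cache: two tensors (keys and values) of `hidden_size` values per cached token, converted to bytes via the dtype's bit width, which is why the bfloat16-vs-float16 choice above does not change the cache footprint (both are 16-bit). A minimal standalone sketch of that arithmetic; the `hidden_size` and `attn_cache_tokens` values below are illustrative assumptions, not taken from this diff:

```python
import torch

# Illustrative values (assumptions, not from the diff): a LLaMA-65B-like
# hidden size and a cache sized to hold 8192 tokens.
hidden_size = 8192
attn_cache_tokens = 8192
torch_dtype = torch.bfloat16  # 16 bits per value, same width as float16

# Keys + values: 2 tensors of hidden_size values per cached token.
cache_values_per_block = 2 * hidden_size * attn_cache_tokens

# Convert the value count to bytes using the dtype's bit width.
cache_bytes_per_block = cache_values_per_block * torch.finfo(torch_dtype).bits // 8

print(f"{cache_bytes_per_block / 1024**2:.0f} MiB per block")  # 256 MiB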