ソースを参照

Disable chunked_forward() on AVX512 CPUs (#179)

Alexander Borzunov 2 年 前
コミット
55698381d0
3 ファイル変更34 行追加12 行削除
  1. 1 0
      setup.cfg
  2. 25 8
      src/petals/bloom/modeling_utils.py
  3. 8 4
      src/petals/client/remote_model.py

+ 1 - 0
setup.cfg

@@ -42,6 +42,7 @@ install_requires =
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
+    cpufeature>=0.2.0
 
 [options.extras_require]
 dev =

+ 25 - 8
src/petals/bloom/modeling_utils.py

@@ -4,9 +4,11 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 See commit history for authorship.
 """
 
+import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from cpufeature import CPUFeature
 from hivemind import get_logger
 from torch import nn
 from transformers import BloomConfig
@@ -24,7 +26,14 @@ class LMHead(nn.Module):
     def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
         super().__init__()
         self.word_embeddings = word_embeddings
-        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
+
+        self.use_chunked_forward = config.use_chunked_forward
+        if self.use_chunked_forward == "auto":
+            # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
+            # Otherwise, it's ~8x slower.
+            self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+        self.chunked_forward_step = config.chunked_forward_step
+        self._bf16_warning_shown = False
 
     @property
     def in_features(self) -> int:
@@ -46,9 +55,9 @@ class LMHead(nn.Module):
         word_embeddings = self.word_embeddings.weight
 
         if (
-            self.chunk_size is not None
-            and word_embeddings.dtype in [torch.float16, torch.bfloat16]
+            word_embeddings.dtype in [torch.float16, torch.bfloat16]
             and word_embeddings.device.type == "cpu"
+            and self.use_chunked_forward
         ):
             lm_logits = self.chunked_forward(hidden_states)
         else:
@@ -59,9 +68,17 @@ class LMHead(nn.Module):
 
     def chunked_forward(self, hidden_states):
         """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
-        chunk_size: provides trade-off between efficiency and extra memory consumption.
+        chunked_forward_step: provides trade-off between efficiency and extra memory consumption.
         """
-        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
+        assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"
+
+        if not self._bf16_warning_shown:
+            if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
+                logger.warning(
+                    "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
+                    "Consider loading the model with torch_dtype='float32'"
+                )
+            self._bf16_warning_shown = True
 
         word_embeddings = self.word_embeddings.weight
         num_embeddings = self.word_embeddings.num_embeddings
@@ -69,7 +86,7 @@ class LMHead(nn.Module):
         hidden_states = hidden_states.float()
         output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
 
-        for i in range(0, num_embeddings, self.chunk_size):
-            chunk = word_embeddings[i : i + self.chunk_size].float()
-            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
+        for i in range(0, num_embeddings, self.chunked_forward_step):
+            chunk = word_embeddings[i : i + self.chunked_forward_step].float()
+            output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
         return output

+ 8 - 4
src/petals/client/remote_model.py

@@ -1,6 +1,6 @@
 import os
 from contextlib import contextmanager
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import hivemind
 import torch
@@ -34,11 +34,15 @@ class DistributedBloomConfig(BloomConfig):
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     daemon_startup_timeout: int = 30
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    chunk_size_for_efficient_fp16_on_cpu: Optional[int] = 10000
-    # Chunk size for efficient half-precision on CPU in the LM head. Set to None if your CPU works fast with bfloat16.
+    request_timeout: int = 30  # a number of seconds for waiting result from each node
+
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
     tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
-    request_timeout: int = 30  # a number of seconds for waiting result from each node
+
+    # This settings matter for running the client with dtype bfloat16 on CPU.
+    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
+    use_chunked_forward: Union[str, bool] = "auto"
+    chunked_forward_step: int = 16384
 
 
 original_register_parameter = nn.Module.register_parameter