
Disable chunked_forward() on AVX512 CPUs (#179)

Alexander Borzunov, 2 years ago
commit 55698381d0
3 changed files with 34 additions and 12 deletions
  1. setup.cfg (+1 -0)
  2. src/petals/bloom/modeling_utils.py (+25 -8)
  3. src/petals/client/remote_model.py (+8 -4)

+ 1 - 0
setup.cfg

@@ -42,6 +42,7 @@ install_requires =
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
+    cpufeature>=0.2.0
 
 [options.extras_require]
 dev =

+ 25 - 8
src/petals/bloom/modeling_utils.py

@@ -4,9 +4,11 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 See commit history for authorship.
 """
 
+import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from cpufeature import CPUFeature
 from hivemind import get_logger
 from torch import nn
 from transformers import BloomConfig
@@ -24,7 +26,14 @@ class LMHead(nn.Module):
     def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
         super().__init__()
         self.word_embeddings = word_embeddings
-        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
+
+        self.use_chunked_forward = config.use_chunked_forward
+        if self.use_chunked_forward == "auto":
+            # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
+            # Otherwise, it's ~8x slower.
+            self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+        self.chunked_forward_step = config.chunked_forward_step
+        self._bf16_warning_shown = False
 
     @property
     def in_features(self) -> int:
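
Note on the auto-detection above: the cpufeature package (added to setup.cfg in this commit) exposes a CPUFeature dict of CPU flags, and the code requires both the hardware flag and OS-level support before taking the plain bfloat16 path. A minimal standalone sketch of the same check, illustrative rather than taken from the project:

    # Standalone sketch of the AVX512 check (variable names are illustrative)
    from cpufeature import CPUFeature

    avx512_usable = bool(CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
    use_chunked_forward = not avx512_usable
    print(f"AVX512 usable: {avx512_usable}; chunked_forward() enabled: {use_chunked_forward}")
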
@@ -46,9 +55,9 @@ class LMHead(nn.Module):
         word_embeddings = self.word_embeddings.weight
 
         if (
-            self.chunk_size is not None
-            and word_embeddings.dtype in [torch.float16, torch.bfloat16]
+            word_embeddings.dtype in [torch.float16, torch.bfloat16]
             and word_embeddings.device.type == "cpu"
+            and self.use_chunked_forward
         ):
             lm_logits = self.chunked_forward(hidden_states)
         else:
@@ -59,9 +68,17 @@ class LMHead(nn.Module):
 
     def chunked_forward(self, hidden_states):
         """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
-        chunk_size: provides trade-off between efficiency and extra memory consumption.
+        chunked_forward_step: provides trade-off between efficiency and extra memory consumption.
         """
-        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
+        assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"
+
+        if not self._bf16_warning_shown:
+            if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
+                logger.warning(
+                    "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
+                    "Consider loading the model with torch_dtype='float32'"
+                )
+            self._bf16_warning_shown = True
 
         word_embeddings = self.word_embeddings.weight
         num_embeddings = self.word_embeddings.num_embeddings
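
The warning above is gated on a memory check: it fires only when a float32 copy of the word embeddings (4 bytes per element) would fit within 90% of total system RAM, i.e. when the suggested torch_dtype='float32' is actually feasible. A sketch of that arithmetic in isolation, with an assumed BLOOM-176B-sized embedding matrix:

    # Illustrative sketch of the RAM check that guards the bfloat16 warning
    import psutil

    vocab_size, hidden_size = 250_880, 14_336       # assumed BLOOM-176B embedding shape
    fp32_bytes = vocab_size * hidden_size * 4       # a float32 copy costs 4 bytes per element
    fits_in_ram = fp32_bytes < 0.9 * psutil.virtual_memory().total
    print(f"float32 copy needs {fp32_bytes / 2**30:.1f} GiB; suggest torch_dtype='float32': {fits_in_ram}")
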
@@ -69,7 +86,7 @@ class LMHead(nn.Module):
         hidden_states = hidden_states.float()
         output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
 
-        for i in range(0, num_embeddings, self.chunk_size):
-            chunk = word_embeddings[i : i + self.chunk_size].float()
-            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
+        for i in range(0, num_embeddings, self.chunked_forward_step):
+            chunk = word_embeddings[i : i + self.chunked_forward_step].float()
+            output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
         return output
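
For reference, the chunk-wise matmul that chunked_forward() implements can be sketched in isolation as follows: only chunked_forward_step rows of the half-precision embedding matrix are cast to fp32 at any one time, trading a small Python-loop overhead for a much smaller temporary buffer. This is an illustrative, self-contained sketch with assumed toy shapes, not the project's code:

    # Self-contained sketch of the chunk-wise fp32 matmul technique
    import torch
    import torch.nn.functional as F

    def chunked_lm_head(hidden_states: torch.Tensor, word_embeddings: torch.Tensor, step: int = 16384) -> torch.Tensor:
        """Compute hidden_states @ word_embeddings.T, casting `step` embedding rows to fp32 at a time."""
        hidden_states = hidden_states.float()
        num_embeddings = word_embeddings.shape[0]
        output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
        for i in range(0, num_embeddings, step):
            chunk = word_embeddings[i : i + step].float()   # only this slice is ever held in fp32
            output[..., i : i + step] = F.linear(hidden_states, chunk)
        return output

    # Toy demo: batch 1, 8 tokens, hidden size 64, a 100k-row vocabulary stored in bfloat16
    logits = chunked_lm_head(torch.randn(1, 8, 64), torch.randn(100_000, 64, dtype=torch.bfloat16))
    print(logits.shape)  # torch.Size([1, 8, 100000])

With the default chunked_forward_step of 16384 and a BLOOM-sized hidden dimension of 14336, the temporary fp32 chunk is roughly 0.9 GB, versus roughly 14 GB for casting the full 250k-row embedding matrix at once.
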

+ 8 - 4
src/petals/client/remote_model.py

@@ -1,6 +1,6 @@
 import os
 from contextlib import contextmanager
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import hivemind
 import torch
@@ -34,11 +34,15 @@ class DistributedBloomConfig(BloomConfig):
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     daemon_startup_timeout: int = 30
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    chunk_size_for_efficient_fp16_on_cpu: Optional[int] = 10000
-    # Chunk size for efficient half-precision on CPU in the LM head. Set to None if your CPU works fast with bfloat16.
+    request_timeout: int = 30  # a number of seconds for waiting result from each node
+
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
     tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
+
+    # These settings matter for running the client with dtype bfloat16 on CPU.
+    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
+    use_chunked_forward: Union[str, bool] = "auto"
+    chunked_forward_step: int = 16384
 
 
 original_register_parameter = nn.Module.register_parameter
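
Finally, a hedged usage sketch of the two new config fields; the from_pretrained call and model name are illustrative assumptions rather than part of this commit:

    # Hypothetical client-side override of the new fields (only the field names come from this commit)
    from petals.client.remote_model import DistributedBloomConfig

    config = DistributedBloomConfig.from_pretrained("bigscience/bloom-petals")  # model name is an assumption
    config.use_chunked_forward = True      # force chunked_forward() even if AVX512 is detected
    config.chunked_forward_step = 8192     # smaller chunks: less extra RAM, more per-chunk overhead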