
Allow disabling chunked forward (#176)

Alexander Borzunov committed 2 years ago · commit 6948a0c5ee
2 changed files with 7 additions and 3 deletions
  1. src/petals/bloom/modeling_utils.py (+5 −2)
  2. src/petals/client/remote_model.py (+2 −1)

src/petals/bloom/modeling_utils.py (+5 −2)

@@ -45,8 +45,11 @@ class LMHead(nn.Module):
     def forward(self, hidden_states):
         word_embeddings = self.word_embeddings.weight
 
-        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
-        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
+        if (
+            self.chunk_size is not None
+            and word_embeddings.dtype in [torch.float16, torch.bfloat16]
+            and word_embeddings.device.type == "cpu"
+        ):
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
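For context, here is a minimal, self-contained sketch of the dispatch this hunk changes: with the added check, setting `chunk_size` to `None` makes the LM head skip `chunked_forward` entirely and fall back to a single matmul. This is an illustration only, not the actual `petals.bloom.modeling_utils` code; in particular, the "project one vocabulary slice at a time" strategy is an assumption about what `chunked_forward` does.

```python
from typing import Optional

import torch


def lm_head_forward(hidden_states: torch.Tensor,
                    word_embeddings: torch.Tensor,  # [vocab_size, hidden_size]
                    chunk_size: Optional[int]) -> torch.Tensor:
    half_precision_on_cpu = (
        word_embeddings.dtype in (torch.float16, torch.bfloat16)
        and word_embeddings.device.type == "cpu"
    )
    if chunk_size is not None and half_precision_on_cpu:
        # Chunked path: upcast one vocabulary slice at a time, so only a small
        # fp32 copy of the half-precision weights exists at any moment.
        hidden = hidden_states.float()
        logits = [
            hidden @ word_embeddings[i : i + chunk_size].float().t()
            for i in range(0, word_embeddings.shape[0], chunk_size)
        ]
        return torch.cat(logits, dim=-1)
    # chunk_size=None (the new option) or the weights are already fast to use:
    # do the whole projection in one matmul, matching hidden_states' dtype.
    return hidden_states @ word_embeddings.to(hidden_states.dtype).t()
```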

src/petals/client/remote_model.py (+2 −1)

@@ -34,7 +34,8 @@ class DistributedBloomConfig(BloomConfig):
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     daemon_startup_timeout: int = 30
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
+    chunk_size_for_efficient_fp16_on_cpu: Optional[int] = 10000
+    # Chunk size for efficient half-precision on CPU in the LM head. Set to None if your CPU works fast with bfloat16.
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
     tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
     request_timeout: int = 30  # a number of seconds for waiting result from each node
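A usage sketch for the new option: setting `chunk_size_for_efficient_fp16_on_cpu` to `None` disables chunked forward in the LM head. The config class is the one changed above; the checkpoint name and the way the config is later handed to the client model are assumptions for illustration, not part of this commit.

```python
from petals.client.remote_model import DistributedBloomConfig

# Example checkpoint name (assumption); substitute the one you actually use.
config = DistributedBloomConfig.from_pretrained("bigscience/bloom-petals")

# New in this commit: None disables chunked forward, so the LM head performs
# a single matmul even for half-precision weights on CPU.
config.chunk_size_for_efficient_fp16_on_cpu = None
```

The resulting config would then be passed to the distributed model's `from_pretrained` call in the usual Hugging Face way (e.g. via `config=config`); that step is not shown in this diff.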