瀏覽代碼

Split long sequences into chunks (#403)

This PR is designed to avoid OOMs when processing long sequences that happen due to the huge attention logits matrices.

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic 2 年之前
父節點
當前提交
5af04524dd

+ 2 - 1
.github/workflows/run-tests.yaml

@@ -37,7 +37,8 @@ jobs:
 
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
             --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
-            --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server1.log &
+            --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
+            --adapters $ADAPTER_NAME &> server1.log &
           SERVER1_PID=$!
 
           sleep 5  # wait for the first server to initialize DHT

+ 2 - 0
src/petals/cli/run_server.py

@@ -74,6 +74,8 @@ def main():
     parser.add_argument('--max_batch_size', type=int, default=None,
                         help='The total number of tokens in the same batch will not exceed this value. '
                              'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+    parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
+                        help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
     parser.add_argument('--attn_cache_tokens', type=int, default=None,
                         help='The number of past attention key/value pairs that will be stored between inference steps. '
                              'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')

+ 38 - 3
src/petals/server/backend.py

@@ -27,7 +27,13 @@ class TransformerBackend(ModuleBackend):
     _peft_module = None
 
     def __init__(
-        self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
+        self,
+        *args,
+        config: PretrainedConfig,
+        memory_cache: MemoryCache,
+        backend_dtype: torch.dtype,
+        max_chunk_size_bytes: int,
+        **kwargs,
     ):
         import petals.utils.peft as _peft_module
 
@@ -37,6 +43,8 @@ class TransformerBackend(ModuleBackend):
         assert isinstance(self.module, TensorParallel)
         self.config = config
         self.memory_cache = memory_cache
+        self.max_chunk_size_bytes = max_chunk_size_bytes
+
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
         for name, buf in self.module.named_buffers():
@@ -55,6 +63,7 @@ class TransformerBackend(ModuleBackend):
         )
 
         self.dtype = backend_dtype
+        self.dtype_bytes = torch.finfo(self.dtype).bits // 8
         self.shard_num_heads = []
         for shard in self.module.module_shards:
             for submodule in shard.modules():
@@ -105,14 +114,40 @@ class TransformerBackend(ModuleBackend):
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        seq_len = hidden_states.shape[1]
+
         with self.memory_cache.use_cache(
             *inference_info.cache_handles
         ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
+
+            # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`
+            # reserved in `Server._choose_num_blocks()`. This saves us from OOMs if `max_chunk_size_bytes`
+            # is at least 4-6x less than `autograd_memory`.
+            max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info)
+            output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
-            hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+            for offset in range(0, seq_len, max_chunk_length):
+                hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
+                output_hidden_states_chunk, new_kvs = self.module.forward(
+                    hidden_states_chunk, layer_past=layer_past, use_cache=True
+                )
+                if seq_len > max_chunk_length:
+                    output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk
+                else:
+                    output_hidden_states = output_hidden_states_chunk  # saves one memcopy
+                layer_past = new_kvs
+
             self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
-            return (hidden_states,)
+            return (output_hidden_states,)
+
+    def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, inference_info: InferenceMetadata) -> int:
+        # We assume that attention logit matrices are the main thing that consumes memory, given that
+        # the model uses multi-query attention
+        batch_size, seq_length, hidden_size = hidden_states.shape
+        worst_case_length = inference_info.prefix_length + seq_length
+        attn_bytes_per_token = max(self.shard_num_heads) * batch_size * self.dtype_bytes * worst_case_length
+        return max(1, self.max_chunk_size_bytes // attn_bytes_per_token)
 
     def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
         """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""

+ 5 - 0
src/petals/server/server.py

@@ -58,6 +58,7 @@ class Server:
         inference_max_length: Optional[int] = None,
         min_batch_size: int = 1,
         max_batch_size: Optional[int] = None,
+        max_chunk_size_bytes: int = 256 * 1024 * 1024,
         attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
         revision: Optional[str] = None,
@@ -183,6 +184,7 @@ class Server:
             inference_max_length = 8192 if is_multiquery_attn else 2048
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
+        self.max_chunk_size_bytes = max_chunk_size_bytes
 
         # For attention cache in GPU or RAM
         if attn_cache_tokens is None:
@@ -312,6 +314,7 @@ class Server:
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
                 max_batch_size=self.max_batch_size,
+                max_chunk_size_bytes=self.max_chunk_size_bytes,
                 inference_max_length=self.inference_max_length,
                 torch_dtype=self.torch_dtype,
                 cache_dir=self.cache_dir,
@@ -412,6 +415,7 @@ class ModuleContainer(threading.Thread):
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
+        max_chunk_size_bytes: int,
         torch_dtype: torch.dtype,
         cache_dir: str,
         max_disk_space: int,
@@ -477,6 +481,7 @@ class ModuleContainer(threading.Thread):
                     config=block_config,
                     memory_cache=memory_cache,
                     backend_dtype=torch_dtype,
+                    max_chunk_size_bytes=max_chunk_size_bytes,
                     args_schema=(
                         BatchTensorDescriptor(
                             1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression

+ 9 - 3
tests/test_full_model.py

@@ -28,7 +28,7 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.num_hidden_layers
 
-    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]
 
     with torch.inference_mode():
         parallel_outputs = model.forward(test_inputs).logits
@@ -43,8 +43,14 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
                 recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
 
             for t in range(embs.shape[1]):
-                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
-                if t == int(embs.shape[1] // 2) and pass_empty_tensors:
+                if t == 4:
+                    recurrent_outputs.append(sess.step(embs[:, 4:9, :]))
+                elif 4 < t < 9:
+                    continue
+                else:
+                    recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
+
+                if t == 2 and pass_empty_tensors:
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))