@@ -27,7 +27,13 @@ class TransformerBackend(ModuleBackend):
     _peft_module = None
 
     def __init__(
-        self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
+        self,
+        *args,
+        config: PretrainedConfig,
+        memory_cache: MemoryCache,
+        backend_dtype: torch.dtype,
+        max_chunk_size_bytes: int,
+        **kwargs,
     ):
         import petals.utils.peft as _peft_module
 
@@ -37,6 +43,8 @@ class TransformerBackend(ModuleBackend):
         assert isinstance(self.module, TensorParallel)
         self.config = config
         self.memory_cache = memory_cache
+        self.max_chunk_size_bytes = max_chunk_size_bytes
+
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
         for name, buf in self.module.named_buffers():
@@ -55,6 +63,7 @@ class TransformerBackend(ModuleBackend):
         )
 
         self.dtype = backend_dtype
+        self.dtype_bytes = torch.finfo(self.dtype).bits // 8
         self.shard_num_heads = []
         for shard in self.module.module_shards:
             for submodule in shard.modules():
@@ -105,14 +114,40 @@ class TransformerBackend(ModuleBackend):
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        seq_len = hidden_states.shape[1]
+
         with self.memory_cache.use_cache(
             *inference_info.cache_handles
         ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
+
+            # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`
+            # reserved in `Server._choose_num_blocks()`. This saves us from OOMs if `max_chunk_size_bytes`
+            # is at least 4-6x less than `autograd_memory`.
+            max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info)
+            output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
-            hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+            for offset in range(0, seq_len, max_chunk_length):
+                hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
+                output_hidden_states_chunk, new_kvs = self.module.forward(
+                    hidden_states_chunk, layer_past=layer_past, use_cache=True
+                )
+                if seq_len > max_chunk_length:
+                    output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk
+                else:
+                    output_hidden_states = output_hidden_states_chunk  # saves one memcopy
+                layer_past = new_kvs
+
             self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
-            return (hidden_states,)
+            return (output_hidden_states,)
+
+    def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, inference_info: InferenceMetadata) -> int:
+        # We assume that attention logit matrices are the main thing that consumes memory, given that
+        # the model uses multi-query attention
+        batch_size, seq_length, hidden_size = hidden_states.shape
+        worst_case_length = inference_info.prefix_length + seq_length
+        attn_bytes_per_token = max(self.shard_num_heads) * batch_size * self.dtype_bytes * worst_case_length
+        return max(1, self.max_chunk_size_bytes // attn_bytes_per_token)
 
     def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
         """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""