@@ -1,12 +1,19 @@
 """Code for serving bloom blocks via hivemind-server"""
+from __future__ import annotations
+
+from itertools import chain
 from typing import Any, Dict, Sequence, Tuple
 
 import torch
-from hivemind import BatchTensorDescriptor
+from hivemind import BatchTensorDescriptor, TensorDescriptor
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
+from tensor_parallel import TensorParallel
+from tensor_parallel.tensor_parallel import PerDeviceTensors
+from transformers import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomAttention
 
-from petals.bloom.block import WrappedBloomBlock
+from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
@@ -17,9 +24,10 @@ logger = get_logger(__file__)
 class TransformerBackend(ModuleBackend):
     """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
 
-    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
+    def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
         super().__init__(*args, **kwargs)
-        assert isinstance(self.module, WrappedBloomBlock)
+        assert isinstance(self.module, TensorParallel)
+        self.config = config
         self.memory_cache = memory_cache
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
@@ -27,18 +35,26 @@ class TransformerBackend(ModuleBackend):
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
         max_batch_size = self.forward_pool.max_batch_size
+        device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
+            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
         )
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
+            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
+            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
         )
 
         assert backend_dtype is not None
         self.dtype = backend_dtype
+        self.shard_num_heads = []
+        for shard in self.module.module_shards:
+            for submodule in shard.modules():
+                if isinstance(submodule, BloomAttention):
+                    self.shard_num_heads.append(submodule.num_heads)
+        assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head
+
         self.inference_schema = (
             (
                 *self.args_schema,
@@ -48,44 +64,60 @@ class TransformerBackend(ModuleBackend):
             self.kwargs_schema,
         )
 
+    def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
+        """Create tensor descriptors for attention cache tensors used during inference_step"""
+        head_dim = self.config.hidden_size // self.config.n_head
+        cache_tensors = []
+        for device, num_heads in zip(self.module.devices, self.shard_num_heads):
+            keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
+            values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
+            cache_tensors.extend((keys, values))
+        return cache_tensors
+
     def inference_step(
-        self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, cache_metadata: torch.LongTensor
+        self,
+        hidden_states: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
-        num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim
         with torch.inference_mode():
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-            cache_handle, rel_index, prefix_length = map(int, cache_metadata[0])
-
-            with self.memory_cache.use_cache(cache_handle) as cache:
-                batch_size = cache.shape[2]
-                max_length = cache.shape[-1] // (head_dim * num_heads)
-                assert isinstance(self.module, WrappedBloomBlock) and cache.shape[1] == 2 and cache.ndim == 4
-                if not is_dummy(hypo_ids):
-                    assert hypo_ids.shape[0] == batch_size
-                    cache[rel_index, :, :] = cache[rel_index, :, hypo_ids]  # in-place reorder cache by hypo ids
-                key_cache = cache[rel_index, 0].view(batch_size, num_heads, head_dim, max_length)
-                value_cache = cache[rel_index, 1].view(batch_size, num_heads, max_length, head_dim)
-
-                key_past = key_cache.flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
-                value_past = value_cache.flatten(0, 1)[:, :prefix_length, :]  # [batch * num_heads, kv_length, head_dim]
-                logger.debug(
-                    f"Metadata: {cache_metadata}, past_k.shape={key_past.shape}, past_v.shape={value_past.shape}"
-                )
-                hidden_states, (new_key, new_value) = self.module.forward(
-                    hidden_states, layer_past=(key_past, value_past), use_cache=True
-                )
-                new_length = new_key.shape[-1]
-                assert new_length > prefix_length
-                assert new_key.shape[0] == key_past.shape[0] and new_value.shape[0] == value_past.shape[0]
-                assert new_key.shape[-1] == new_length and new_value.shape[-2] == new_length
-                new_key = new_key.view(batch_size, num_heads, head_dim, -1)
-                new_value = new_value.view(batch_size, num_heads, -1, head_dim)
-                key_cache[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
-                value_cache[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
+            with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+                self._reorder_cache_inplace(cache_tensors, hypo_ids)
+                layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
+                hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+                self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
                 return (hidden_states,)
 
+    def _reorder_cache_inplace(self, cache_tensors: Sequence[torch.Tensor], hypo_ids: torch.Tensor):
+        """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
+        if not is_dummy(hypo_ids):
+            for cache_tensor in cache_tensors:
+                cache_tensor[...] = cache_tensor[hypo_ids]  # in-place reorder cache by hypo ids
+
+    def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
+        """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
+        key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
+        for i in range(len(key_cache)):
+            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
+            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]  # [batch * num_heads, kv_length, head_dim]
+        layer_past = tuple(chain(*zip(key_cache, value_cache)))
+        return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past
+
+    def _update_cache_inplace(
+        self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
+    ):
+        """Writes new key/value tensors back into cache, works in-place"""
+        _batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape
+        for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
+            new_key = new_key.view(*cache_key.shape[:3], new_length)
+            cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
+        for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]):
+            new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
+            cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
+
     def get_pools(self) -> Sequence[PrioritizedTaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
 
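
A minimal sketch of the cache layout the new helpers assume, in plain PyTorch with made-up sizes (the two-shard head split, batch size and lengths are illustrative only, not taken from the patch): the cache holds one (keys, values) pair per tensor-parallel shard, interleaved in that order, with keys stored as [batch, num_heads, head_dim, max_length] and values as [batch, num_heads, max_length, head_dim], the same layout the comments in _select_layer_past describe. The snippet mirrors the prefix slicing of _select_layer_past and the write-back of _update_cache_inplace.

import torch

batch_size, max_length, head_dim = 2, 16, 8
shard_num_heads = [3, 5]          # hypothetical two-GPU split of an 8-head block
prefix_length, new_length = 4, 5  # 4 tokens already cached, 1 token just generated

# Interleaved per-shard cache, shaped like the descriptors from get_inference_cache_descriptors():
# keys:   [batch, num_heads, head_dim, max_length]  (stored pre-transposed)
# values: [batch, num_heads, max_length, head_dim]
cache_tensors = []
for num_heads in shard_num_heads:
    cache_tensors.append(torch.zeros(batch_size, num_heads, head_dim, max_length))  # keys
    cache_tensors.append(torch.zeros(batch_size, num_heads, max_length, head_dim))  # values

# _select_layer_past: flatten (batch, num_heads) and keep only the first prefix_length positions
layer_past = []
for key_cache, value_cache in zip(cache_tensors[0::2], cache_tensors[1::2]):
    layer_past.append(key_cache.flatten(0, 1)[:, :, :prefix_length])  # [batch * heads, head_dim, prefix]
    layer_past.append(value_cache.flatten(0, 1)[:, :prefix_length])   # [batch * heads, prefix, head_dim]

# _update_cache_inplace: write only the freshly generated positions back into the cache
for shard, num_heads in enumerate(shard_num_heads):
    new_key = torch.randn(batch_size * num_heads, head_dim, new_length)    # this shard's new keys
    new_value = torch.randn(batch_size * num_heads, new_length, head_dim)  # this shard's new values
    key_cache, value_cache = cache_tensors[2 * shard], cache_tensors[2 * shard + 1]
    new_key = new_key.view(batch_size, num_heads, head_dim, new_length)
    new_value = new_value.view(batch_size, num_heads, new_length, head_dim)
    key_cache[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
    value_cache[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]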