@@ -3,10 +3,11 @@ from __future__ import annotations
 
 from collections import Counter
 from itertools import chain
-from typing import Any, Dict, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
 from hivemind import BatchTensorDescriptor, TensorDescriptor
+from hivemind.moe.expert_uid import ExpertUID
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 from tensor_parallel import TensorParallel
@@ -15,7 +16,7 @@ from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomAttention
 
 from petals.data_structures import InferenceMetadata
-from petals.server.memory_cache import MemoryCache
+from petals.server.memory_cache import Handle, MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
 
@@ -39,7 +40,7 @@ class TransformerBackend(ModuleBackend):
         device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
             self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
-        )
+        )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
             self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
         )
@@ -69,6 +70,21 @@ class TransformerBackend(ModuleBackend):
         for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
             self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
 
+    @staticmethod
+    def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
+ """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
+        assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
+        first_pool = next(iter(backends.values())).inference_pool
+        merged_pool = PrioritizedTaskPool(
+            _MergedInferenceStep(backends),
+            max_batch_size=first_pool.max_batch_size,
+            device=first_pool.device,
+            name=f"merged_inference",
+        )
+        for backend in backends.values():
+            assert not backend.inference_pool.is_alive()
+            backend.inference_pool = merged_pool
+
     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""
         head_dim = self.config.hidden_size // self.config.n_head
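Usage note (not part of the diff): a minimal sketch of how the merged pool would be installed, assuming the server keeps its blocks in a Dict[ExpertUID, TransformerBackend] and has not started any pool yet; the block UIDs and backend instances below are hypothetical.

    # hypothetical, pre-built TransformerBackend instances keyed by block UID
    backends: Dict[ExpertUID, TransformerBackend] = {"model.0": backend_0, "model.1": backend_1}
    TransformerBackend.merge_inference_pools_inplace(backends)
    # every block now shares a single rpc_inference queue backed by _MergedInferenceStep
    assert backends["model.0"].inference_pool is backends["model.1"].inference_pool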
@@ -79,22 +95,20 @@ class TransformerBackend(ModuleBackend):
             cache_tensors.extend((keys, values))
         return cache_tensors
 
+    @torch.inference_mode()
     def inference_step(
         self,
         hidden_states: torch.Tensor,
         hypo_ids: torch.LongTensor,
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
-        with torch.inference_mode():
-            assert (
-                hidden_states.ndim == 3
-            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-            with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
-                self._reorder_cache_inplace(cache_tensors, hypo_ids)
-                layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
-                hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
-                self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
-                return (hidden_states,)
+        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+            self._reorder_cache_inplace(cache_tensors, hypo_ids)
+            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
+            hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
+            return (hidden_states,)
 
     def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
         """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
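Side note on the inference_step change above: decorating the method with torch.inference_mode() has the same effect as the removed with-block, since inference_mode works both as a context manager and as a decorator. A standalone illustration (not from this patch):

    import torch

    @torch.inference_mode()  # autograd tracking is disabled for the entire call
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    y = double(torch.ones(3, requires_grad=True))
    assert not y.requires_grad  # same effect as wrapping the body in a with torch.inference_mode() block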
@@ -139,3 +153,24 @@ class TransformerBackend(ModuleBackend):
         dummy = torch.tensor([])
         for p in self.module.parameters():
             p.data = dummy
+
+
+class _MergedInferenceStep:
+    def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
+        self.backends = backends
+
+    def __call__(
+        self,
+        hidden_states: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        inference_infos: Sequence[InferenceMetadata],
+        *optional_prompts: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, ...]:
+        assert len(inference_infos) == len(
+            optional_prompts
+        ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
+        for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
+            if optional_prompt is not None:
+                hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
+            (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
+        return (hidden_states,)
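For reference, a rough sketch of the call convention the merged pool expects: one call covers several consecutive blocks, with one InferenceMetadata per block and one optional prompt per block. All names below are hypothetical.

    # hypothetical driver: run three consecutive blocks in a single pool call
    step = _MergedInferenceStep(backends)   # backends: Dict[ExpertUID, TransformerBackend]
    infos = (info_0, info_1, info_2)        # one InferenceMetadata per block, in execution order
    prompts = (None, None, None)            # optional per-block prompts, same length as infos
    (hidden_states,) = step(hidden_states, hypo_ids, infos, *prompts)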