backend.py

  1. """Code for serving bloom blocks via hivemind-server"""
  2. from __future__ import annotations
  3. from collections import Counter
  4. from itertools import chain
  5. from typing import Any, Dict, Optional, Sequence, Tuple
  6. import torch
  7. from hivemind import BatchTensorDescriptor, TensorDescriptor
  8. from hivemind.moe.expert_uid import ExpertUID
  9. from hivemind.moe.server.module_backend import ModuleBackend
  10. from hivemind.utils import get_logger
  11. from tensor_parallel import TensorParallel
  12. from tensor_parallel.tensor_parallel import PerDeviceTensors
  13. from transformers import BloomConfig
  14. from transformers.models.bloom.modeling_bloom import BloomAttention
  15. from petals.data_structures import InferenceMetadata
  16. from petals.server.memory_cache import Handle, MemoryCache
  17. from petals.server.task_pool import PrioritizedTaskPool
  18. from petals.utils.misc import is_dummy
  19. logger = get_logger(__file__)


class TransformerBackend(ModuleBackend):
    """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""

    def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(self.module, TensorParallel)
        self.config = config
        self.memory_cache = memory_cache
        for name, param in self.module.named_parameters():
            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
        for name, buf in self.module.named_buffers():
            assert not buf.requires_grad, f"Bloom layer buffers must not accumulate gradients, but {name} does"
        max_batch_size = self.forward_pool.max_batch_size
        device = self.module.devices[self.module.output_device_index]
        self.inference_pool = PrioritizedTaskPool(
            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
        )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
        self.forward_pool = PrioritizedTaskPool(
            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
        )
        self.backward_pool = PrioritizedTaskPool(
            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
        )

        assert backend_dtype is not None
        self.dtype = backend_dtype
        self.shard_num_heads = []
        for shard in self.module.module_shards:
            for submodule in shard.modules():
                if isinstance(submodule, BloomAttention):
                    self.shard_num_heads.append(submodule.num_heads)
        assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head
        self.inference_schema = (
            (
                *self.args_schema,
                BatchTensorDescriptor((), dtype=self.dtype),
                BatchTensorDescriptor((), dtype=torch.int64),
            ),
            self.kwargs_schema,
        )
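
        # Pre-compute how many bytes of attention cache a single token occupies on each device:
        # the descriptors below are created with batch_size=1 and max_length=1, so
        # numel() * dtype bit width // 8 is the per-token byte cost of each cache tensor.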
        self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
        for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8

    def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
        """Create tensor descriptors for attention cache tensors used during inference_step"""
        head_dim = self.config.hidden_size // self.config.n_head
        cache_tensors = []
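        # Each tensor-parallel shard gets its own (keys, values) pair, interleaved as
        # (keys_0, values_0, keys_1, values_1, ...); _select_layer_past and _update_cache_inplace
        # rely on this even/odd ordering when they slice cache_tensors[0::2] and cache_tensors[1::2].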
        for device, num_heads in zip(self.module.devices, self.shard_num_heads):
            keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
            values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
            cache_tensors.extend((keys, values))
        return cache_tensors

    @torch.inference_mode()
    def inference_step(
        self,
        hidden_states: torch.Tensor,
        hypo_ids: torch.LongTensor,
        inference_info: InferenceMetadata,
    ) -> Tuple[torch.Tensor, ...]:
        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
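            # One generation step: reorder cache rows to follow the surviving beam hypotheses,
            # slice the first prefix_length positions as layer_past, run the block with
            # use_cache=True, then write the newly produced keys/values back into the cache.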
            self._reorder_cache_inplace(cache_tensors, hypo_ids)
            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
            hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
            return (hidden_states,)

    def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
        """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
        if not is_dummy(hypo_ids):
            for cache_tensor in cache_tensors:
                cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)]  # in-place reorder cache by hypo ids

    def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
        """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
        key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
        for i in range(len(key_cache)):
            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]  # [batch * num_heads, kv_length, head_dim]
        layer_past = tuple(chain(*zip(key_cache, value_cache)))
        return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past

    def _update_cache_inplace(
        self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
    ):
        """Writes new key/value tensors back into cache, works in-place"""
        _batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape
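        # new_kvs covers the full sequence of length new_length (past positions plus the ones just
        # generated), so only the slice [prefix_length:new_length) is missing from the cache and
        # needs to be written back.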
        for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
            new_key = new_key.view(*cache_key.shape[:3], new_length)
            cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
        for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]):
            new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
            cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]

    def get_pools(self) -> Sequence[PrioritizedTaskPool]:
        return self.forward_pool, self.backward_pool, self.inference_pool

    def get_info(self) -> Dict[str, Any]:
        """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
        return dict(super().get_info(), inference_schema=self.inference_schema)

    def shutdown(self):
        # Break the cyclic references, otherwise TransformerBackend may not be garbage-collected
        self.forward_pool = self.backward_pool = self.inference_pool = None

        # Explicitly free the GPU memory. This is not necessary at the time this code is written,
        # but may help to avoid future issues when the module is not garbage-collected for some reason
        dummy = torch.tensor([])
        for p in self.module.parameters():
            p.data = dummy


def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
    """Replace each backend's rpc_inference pools with a combined pool that runs multiple blocks in one call"""
    assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
    first_pool = next(iter(backends.values())).inference_pool
    merged_pool = PrioritizedTaskPool(
        _MergedInferenceStep(backends),
        max_batch_size=first_pool.max_batch_size,
        device=first_pool.device,
        name="merged_inference",
    )
    for backend in backends.values():
        assert not backend.inference_pool.is_alive()
        backend.inference_pool = merged_pool


class _MergedInferenceStep:
    def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
        self.backends = backends

    def __call__(
        self,
        hidden_states: torch.Tensor,
        hypo_ids: torch.LongTensor,
        inference_infos: Sequence[InferenceMetadata],
        *optional_prompts: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, ...]:
        assert len(inference_infos) == len(
            optional_prompts
        ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
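        # Run the requested blocks back to back in a single task; if a prompt tensor was supplied
        # for a block, add it to the first prompt-length positions of hidden_states before that block.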
        for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
            if optional_prompt is not None:
                hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
            (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
        return (hidden_states,)