backend.py

  1. """Code for serving bloom blocks via hivemind-server"""
  2. import multiprocessing as mp
  3. import os
  4. import threading
  5. from concurrent.futures import Future
  6. from dataclasses import dataclass, field
  7. from queue import Empty, PriorityQueue
  8. from typing import Optional, Sequence, Tuple
  9. import torch
  10. from hivemind import use_hivemind_log_handler
  11. from hivemind.moe.server.module_backend import ModuleBackend
  12. from hivemind.moe.server.task_pool import Task, TaskPool
  13. from hivemind.utils import InvalidStateError, MPFuture, get_logger
  14. from src.bloom.from_pretrained import BloomBlock
  15. from src.server.cache import MemoryCache
  16. from src.server.task_broker import SimpleBroker, TaskBrokerBase
  17. use_hivemind_log_handler("in_root_logger")
  18. logger = get_logger(__file__)
  19. MAX_LENGTH = 2048


@dataclass(order=True)
class PrioritizedTask:
    value: int
    task: Task = field(compare=False)
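
# Note: queue.PriorityQueue pops the smallest value first, so _prioritize_tasks below stores the
# broker's score negated; tasks with a higher broker score are therefore served earlier.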


class PrioritizedTaskPool(TaskPool):
    def __init__(self, *args, broker: TaskBrokerBase = SimpleBroker(), **kwargs):
        super().__init__(*args, **kwargs)
        self.broker = broker
        self.dust_queue = mp.Queue(maxsize=self.tasks.maxsize)
        self.priority_queue = PriorityQueue(maxsize=self.tasks.maxsize)

    def submit_task(self, *args: torch.Tensor, dust: float = 0.0) -> Future:
        f = super().submit_task(*args)
        self.dust_queue.put(dust)
        return f

    def _prioritize_tasks(self):
        """Infinite loop prioritizing incoming tasks"""
        while True:
            task = self.tasks.get(block=True)
            dust = self.dust_queue.get(block=True)
            self.priority_queue.put(PrioritizedTask(-self.broker(task, dust), task), block=True)

    def run(self, *args, **kwargs):
        torch.set_num_threads(1)
        logger.info(f"{self.name} starting, pid={os.getpid()}")
        pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
        output_thread = threading.Thread(
            target=self._pool_output_loop, args=[pending_batches], name=f"{self.name}_output", daemon=True
        )
        priority_thread = threading.Thread(
            target=self._prioritize_tasks, args=[], name=f"{self.name}_priority", daemon=True
        )

        try:
            output_thread.start()
            priority_thread.start()
            self._pool_input_loop(pending_batches, *args, **kwargs)
        except KeyboardInterrupt:
            logger.debug("Caught KeyboardInterrupt, shutting down")
        finally:
            output_thread.join()
            priority_thread.join()

    # TODO: this is a copy-paste of the original method, except that we use a different queue
    def iterate_minibatches(self, *args, **kwargs):
        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
        batch = []
        total_size = 0

        while True:
            if total_size >= self.min_batch_size and self.priority_queue.empty():
                yield batch
                batch = []
                total_size = 0
            try:
                logger.debug(f"{self.name} getting next task")
                # the queue holds PrioritizedTask wrappers; unwrap to get the underlying Task
                task = self.priority_queue.get(timeout=self.timeout).task
            except Empty:
                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
                continue

            task_size = self.get_task_size(task)

            if total_size + task_size > self.max_batch_size:
                yield batch
                batch = []
                total_size = 0

            try:
                if task.future.set_running_or_notify_cancel():
                    batch.append(task)
                    total_size += task_size
            except InvalidStateError as e:
                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
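

# A possible extension point (sketch only): the broker decides how much "dust" buys priority.
# Assuming TaskBrokerBase only requires __call__(task, dust) -> float, a broker that normalizes the
# bid by batch size, so that large requests do not starve small ones, could look roughly like this
# (hypothetical, not part of this module; process_func and the pool arguments are placeholders):
#
#     class SizeAwareBroker(TaskBrokerBase):
#         def __call__(self, task: Task, dust: float) -> float:
#             batch_size = len(task.args[0]) if task.args else 1
#             return dust / max(batch_size, 1)
#
#     pool = PrioritizedTaskPool(process_func, max_batch_size=16, name="block0_forward", broker=SizeAwareBroker())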


class InferenceTaskPool(TaskPool):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater than 1"

    def iterate_minibatches(self, *args, **kwargs):
        """Form single-task minibatches: each inference request is dispatched on its own"""

        while True:
            try:
                logger.debug(f"{self.name} getting next task")
                task = self.tasks.get(timeout=self.timeout)
            except Empty:
                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
                continue

            try:
                if task.future.set_running_or_notify_cancel():
                    yield [task]
            except InvalidStateError as e:
                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
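

# Inference requests are deliberately not batched across sessions: TransformerBackend.inference_step
# (below) reads a single attention-cache handle from cache_metadata, so tasks belonging to different
# inference sessions cannot simply be concatenated into one batch.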


class TransformerBackend(ModuleBackend):
    """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""

    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(self.module, BloomBlock)
        self.memory_cache = memory_cache
        for name, param in self.module.named_parameters():
            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
        for name, buf in self.module.named_buffers():
            assert not buf.requires_grad, f"Bloom layer buffers must not accumulate gradients, but {name} does"

        self.inference_pool = InferenceTaskPool(
            self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
        )
        self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
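
    # cache_metadata contract (as used in inference_step below): an int tensor whose first row holds
    # [attention_cache_handle, prefix_length].  The handle addresses a 5-dimensional MemoryCache slot
    # with keys at index 0 and values at index 1 along dim 0 and the sequence axis at dim 2;
    # prefix_length is the number of positions already filled by earlier inference steps.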

    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        with torch.inference_mode():
            attention_cache_handle = int(cache_metadata[0, 0].item())
            prefix_length = int(cache_metadata[0, 1].item())
            hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
            assert (
                hidden_states.ndim == 3
            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"

            with self.memory_cache.use_cache(attention_cache_handle) as cache:
                assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
                logger.debug(f"METADATA: {cache_metadata}, past_k: {past_k.shape}, past_v: {past_v.shape}")
                hidden_states, (new_k, new_v) = self.module.forward(
                    hidden_states, layer_past=layer_past, use_cache=True
                )

                # todo: remove these asserts once we pass all tests
                new_length = new_v.shape[1]
                assert new_length > prefix_length
                assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
                assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
                assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
                assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
                assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
                # write only the newly computed positions back into the shared cache slot
                cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
                cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
                return (hidden_states,)

    def get_pools(self) -> Sequence[TaskPool]:
        return self.forward_pool, self.backward_pool, self.inference_pool
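

# Construction sketch (hypothetical, for orientation only: the exact ModuleBackend keyword arguments
# depend on the installed hivemind version, and `block`, `memory_cache` and the sizes are placeholders):
#
#     backend = TransformerBackend(
#         name="bloom.block0",
#         module=block,               # a BloomBlock whose parameters have requires_grad=False
#         memory_cache=memory_cache,  # shared MemoryCache that hands out attention-cache slots
#         max_batch_size=16,
#     )
#     pools = backend.get_pools()     # forward, backward and inference pools for the server runtime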