
backend.py

  1. """Code for serving bloom blocks via hivemind-server"""
  2. import multiprocessing as mp
  3. import os
  4. import threading
  5. from concurrent.futures import Future
  6. from dataclasses import dataclass, field
  7. from queue import Empty, PriorityQueue
  8. from typing import Any, Dict, Optional, Sequence, Tuple
  9. import torch
  10. from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
  11. from hivemind.moe.server.module_backend import ModuleBackend
  12. from hivemind.moe.server.task_pool import Task, TaskPool
  13. from hivemind.utils import InvalidStateError, get_logger
  14. from src.bloom.from_pretrained import BloomBlock
  15. from src.server.cache import MemoryCache
  16. from src.utils.misc import is_dummy
  17. use_hivemind_log_handler("in_root_logger")
  18. logger = get_logger(__file__)


@dataclass(order=True)
class PrioritizedTask:
    priority: float
    task: Task = field(compare=False)
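
# Because `order=True` compares only `priority` (the task field is excluded from
# comparison via compare=False), a PriorityQueue pops the numerically smallest value
# first. Sketch (task_a / task_b are hypothetical Task objects, not defined here):
#   q = PriorityQueue()
#   q.put(PrioritizedTask(2.0, task_b)); q.put(PrioritizedTask(0.5, task_a))
#   q.get().task  # -> task_a: lower priority value = served sooner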


class PrioritizedTaskPool(TaskPool):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.min_batch_size == 1, "PrioritizedTaskPool does not support batching"

        self.priority_queue = mp.Queue(maxsize=self.tasks._maxsize)
        self.prioritized_task_queue = PriorityQueue(maxsize=self.tasks._maxsize)
        self.undispatched_task_priorities = mp.SimpleQueue()
        self._timestamp = mp.Value(ctypes.c_double, 1.0)

    @property
    def priority(self):
        return (-self._priority.value, -self._timestamp.value)

    @priority.setter
    def priority(self, priority_tuple: Sequence[float]):
        assert len(priority_tuple) == 2, "pool priority must be a tuple of (priority, time_submitted)"
        self._priority.value, self._timestamp.value = map(float, priority_tuple)

    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> Future:
        f = super().submit_task(*args)
        self.priority_queue.put(priority)
        self.undispatched_task_priorities.put(priority)
        # TODO use a single queue here
        return f
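
    # Usage sketch (the caller and tensors are hypothetical): clients that need a faster
    # response pass a smaller priority value, since lower values are popped first:
    #   urgent = pool.submit_task(hidden_states, priority=0.0)
    #   relaxed = pool.submit_task(hidden_states, priority=10.0)
    #   result = urgent.result()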

    def _prioritize_tasks(self):
        """Infinite loop prioritizing incoming tasks"""
        while True:
            task = self.tasks.get(block=True)
            priority = self.priority_queue.get(block=True)
            self.prioritized_task_queue.put(PrioritizedTask(priority, task), block=True)
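            # Note: this pairing assumes submit_task enqueues each task and its priority
            # in the same order, so the two blocking get() calls stay in sync.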

    def run(self, *args, **kwargs):
        torch.set_num_threads(1)
        logger.info(f"{self.name} starting, pid={os.getpid()}")
        pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
        output_thread = threading.Thread(
            target=self._pool_output_loop, args=[pending_batches], name=f"{self.name}_output", daemon=True
        )
        priority_thread = threading.Thread(
            target=self._prioritize_tasks, args=[], name=f"{self.name}_priority", daemon=True
        )

        try:
            output_thread.start()
            priority_thread.start()
            self._pool_input_loop(pending_batches, *args, **kwargs)
        except KeyboardInterrupt:
            logger.debug("Caught KeyboardInterrupt, shutting down")
        finally:
            output_thread.join()
            priority_thread.join()
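
    # run() therefore keeps three loops alive per pool process: this method's input loop
    # (batching), the output thread (delivering results to futures), and the priority
    # thread (merging tasks with their priorities into prioritized_task_queue).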

    def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs):
        """Infinite loop: aggregate tasks into batches and send them to runtime"""
        prev_num_tasks = 0  # number of tasks currently in shared buffer
        batch_index = max(pending_batches.keys(), default=0)
        batch_iterator = self.iterate_minibatches(*args, **kwargs)

        while True:
            # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
            # assumes that tasks are processed in the same order as they are created
            for skip_i in range(prev_num_tasks):
                dispatched_task_timestamp = self.undispatched_task_timestamps.get()
                dispatched_task_priority = self.undispatched_task_priorities.get()
                if skip_i == prev_num_tasks - 1:
                    self.priority = (dispatched_task_priority, dispatched_task_timestamp)

            logger.debug(f"{self.name} getting next batch")
            batch_tasks = next(batch_iterator)
            # save batch futures, _output_loop will deliver on them later
            pending_batches[batch_index] = batch_tasks

            logger.debug(f"{self.name}, batch {batch_index}: aggregating inputs")
            # find or create shared arrays for current batch size
            batch_inputs = [
                torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))
            ]
            batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]

            logger.debug(f"{self.name}, batch {batch_index}: sending to runtime")
            self.batch_sender.send((batch_index, batch_inputs))
            logger.debug(f"{self.name}, batch {batch_index}: sent to runtime")
            prev_num_tasks = len(batch_tasks)
            batch_index += 1

    # TODO: this is a copy-paste of the original method, except that we use different queue
    def iterate_minibatches(self, *args, **kwargs):
        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
        while True:
            try:
                logger.debug(f"{self.name} getting next task")
                task: PrioritizedTask = self.prioritized_task_queue.get(timeout=self.timeout)
            except Empty:
                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
                continue

            try:
                if task.task.future.set_running_or_notify_cancel():
                    yield [task]
            except InvalidStateError as e:
                logger.debug(f"Failed to add task to batch: {task.task.future} raised {e}")
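
    # Since __init__ asserts min_batch_size == 1, every yielded "minibatch" holds exactly
    # one PrioritizedTask; batching across requests is intentionally disabled here.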


class TransformerBackend(ModuleBackend):
    """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""

    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(self.module, BloomBlock)
        self.memory_cache = memory_cache
        for name, param in self.module.named_parameters():
            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
        for name, buf in self.module.named_buffers():
            assert not buf.requires_grad, f"Bloom layer buffers must not accumulate gradients, but {name} does"

        self.inference_pool = PrioritizedTaskPool(
            self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
        )
        self.forward_pool = PrioritizedTaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
        self.backward_pool = PrioritizedTaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
        self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
        self.inference_schema = (
            (
                *self.args_schema,
                BatchTensorDescriptor((), dtype=self.dtype),
                BatchTensorDescriptor((), dtype=torch.int64),
            ),
            self.kwargs_schema,
        )
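
    # cache_metadata layout assumed by inference_step: cache_metadata[0, 0] holds the
    # MemoryCache handle of this request's attention cache and cache_metadata[0, 1] holds
    # the prefix length already stored in that cache (only row 0 is read, since the
    # prioritized pool dispatches one task at a time).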
    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        with torch.inference_mode():
            attention_cache_handle = int(cache_metadata[0, 0].item())
            prefix_length = int(cache_metadata[0, 1].item())
            (hidden_states, hypo_ids) = inputs
            assert (
                hidden_states.ndim == 3
            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"

            with self.memory_cache.use_cache(attention_cache_handle) as cache:
                assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
                if not is_dummy(hypo_ids):
                    cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
                print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
                hidden_states, (new_k, new_v) = self.module.forward(
                    hidden_states, layer_past=layer_past, use_cache=True
                )

                # todo remove these asserts once we pass all tests
                new_length = new_v.shape[1]
                assert new_length > prefix_length
                assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
                assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
                assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
                cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
                cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
                return (hidden_states,)

    def get_pools(self) -> Sequence[TaskPool]:
        return self.forward_pool, self.backward_pool, self.inference_pool

    def get_info(self) -> Dict[str, Any]:
        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
        return dict(super().get_info(), inference_schema=self.inference_schema)
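

# Illustrative request flow (names such as `backends`, `module_uid`, and `request_priority`
# are hypothetical, for orientation only): a connection handler would pick one of the
# backend's pools and submit tensors together with the caller's priority, then wait on
# the returned future:
#   backend = backends[module_uid]
#   future = backend.forward_pool.submit_task(hidden_states, priority=request_priority)
#   outputs = future.result()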