123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177 |
- """Code for serving bloom blocks via hivemind-server"""
- import multiprocessing as mp
- import os
- import threading
- from concurrent.futures import Future
- from dataclasses import dataclass, field
- from queue import Empty, PriorityQueue
- from typing import Optional, Sequence, Tuple
- import torch
- from hivemind import use_hivemind_log_handler
- from hivemind.moe.server.module_backend import ModuleBackend
- from hivemind.moe.server.task_pool import Task, TaskPool
- from hivemind.utils import InvalidStateError, MPFuture, get_logger
- from src.bloom.from_pretrained import BloomBlock
- from src.server.cache import MemoryCache
- from src.server.task_broker import SimpleBroker, TaskBrokerBase
- use_hivemind_log_handler("in_root_logger")
- logger = get_logger(__file__)
- MAX_LENGTH = 2048
- @dataclass(order=True)
- class PrioritizedTask:
- value: int
- task: Task = field(compare=False)
- class PrioritizedTaskPool(TaskPool):
- def __init__(self, *args, broker: TaskBrokerBase = SimpleBroker(), **kwargs):
- super().__init__(*args, **kwargs)
- self.broker = broker
- self.dust_queue = mp.Queue(maxsize=self.tasks.maxsize)
- self.priority_queue = PriorityQueue(maxsize=self.tasks.maxsize)
- def submit_task(self, *args: torch.Tensor, dust: float = 0.0) -> Future:
- f = super().submit_task(*args)
- self.dust_queue.put(dust)
- return f
- def _priortize_tasks(self):
- """Infinite loop prioritizing incoming tasks"""
- while True:
- task = self.tasks.get(block=True)
- dust = self.dust_queue.get(block=True)
- self.priority_queue.put(PrioritizedTask(-self.broker(task, dust), task), block=True)
- def run(self, *args, **kwargs):
- torch.set_num_threads(1)
- logger.info(f"{self.name} starting, pid={os.getpid()}")
- pending_batches = {} # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
- output_thread = threading.Thread(
- target=self._pool_output_loop, args=[pending_batches], name=f"{self.name}_output", daemon=True
- )
- priority_thread = threading.Thread(
- target=self._priortize_tasks, args=[], name=f"{self.name}_priority", daemon=True
- )
- try:
- output_thread.start()
- priority_thread.start()
- self._pool_input_loop(pending_batches, *args, **kwargs)
- except KeyboardInterrupt:
- logger.debug("Caught KeyboardInterrupt, shutting down")
- finally:
- output_thread.join()
- priority_thread.join()
- # TODO: this is a copy-paste of the original method, except that we use different queue
- def iterate_minibatches(self, *args, **kwargs):
- """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
- batch = []
- total_size = 0
- while True:
- if total_size >= self.min_batch_size and self.priority_queue.empty():
- yield batch
- batch = []
- total_size = 0
- try:
- logger.debug(f"{self.name} getting next task")
- task = self.priority_queue.get(timeout=self.timeout)
- except Empty:
- logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
- continue
- task_size = self.get_task_size(task)
- if total_size + task_size > self.max_batch_size:
- yield batch
- batch = []
- total_size = 0
- try:
- if task.future.set_running_or_notify_cancel():
- batch.append(task)
- total_size += task_size
- except InvalidStateError as e:
- logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
- class InferenceTaskPool(TaskPool):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater 1"
- def iterate_minibatches(self, *args, **kwargs):
- """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
- while True:
- try:
- logger.debug(f"{self.name} getting next task")
- task = self.tasks.get(timeout=self.timeout)
- except Empty:
- logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
- continue
- try:
- if task.future.set_running_or_notify_cancel():
- yield [task]
- except InvalidStateError as e:
- logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
- class TransformerBackend(ModuleBackend):
- """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
- def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
- super().__init__(*args, **kwargs)
- assert isinstance(self.module, BloomBlock)
- self.memory_cache = memory_cache
- for name, param in self.module.named_parameters():
- assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
- for name, buf in self.module.named_buffers():
- assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
- self.inference_pool = InferenceTaskPool(
- self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
- )
- self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
- def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
- with torch.inference_mode():
- attention_cache_handle = int(cache_metadata[0, 0].item())
- prefix_length = int(cache_metadata[0, 1].item())
- hidden_states = inputs[0] # todo: in future, it would be best to support attention mask here
- assert (
- hidden_states.ndim == 3
- ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
- with self.memory_cache.use_cache(attention_cache_handle) as cache:
- assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
- layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
- print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
- hidden_states, (new_k, new_v) = self.module.forward(
- hidden_states, layer_past=layer_past, use_cache=True
- )
- # todo remove these asserts once we pass all tests
- new_length = new_v.shape[1]
- assert new_length > prefix_length
- assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
- assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
- assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
- assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
- assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
- cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
- cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
- return (hidden_states,)
- def get_pools(self) -> Sequence[TaskPool]:
- return self.forward_pool, self.backward_pool, self.inference_pool
|