@@ -1,16 +1,46 @@
 """Code for serving bloom blocks via hivemind-server"""
+from queue import Empty
 from typing import Sequence, Tuple
 
 import torch
+from hivemind import use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
+from hivemind.utils import InvalidStateError, get_logger
 
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
 
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
 MAX_LENGTH = 2048
 
 
+class InferenceTaskPool(TaskPool):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater 1"
+
+    def iterate_minibatches(self, *args, **kwargs):
+        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
+
+        while True:
+            try:
+                logger.debug(f"{self.name} getting next task")
+                task = self.tasks.get(timeout=self.timeout)
+            except Empty:
+                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
+                continue
+
+            try:
+                if task.future.set_running_or_notify_cancel():
+                    yield [task]
+            except InvalidStateError as e:
+                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
+
+
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
@@ -23,7 +53,9 @@ class TransformerBackend(ModuleBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
+        self.inference_pool = InferenceTaskPool(
+            self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
+        )
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
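
Note on the change above: even though `InferenceTaskPool` now receives `max_batch_size` from the forward pool, its `iterate_minibatches` yields every task as its own single-element batch, after first checking that the request's future is still live. The snippet below is a standalone sketch of that batching behaviour, not part of the patch: it substitutes the standard-library `queue.Queue` and `concurrent.futures.Future`/`InvalidStateError` for hivemind's task queue and `MPFuture`, and the names `Task` and `iterate_single_task_batches` are invented for illustration.

from collections import namedtuple
from concurrent.futures import Future, InvalidStateError
from queue import Empty, Queue

# Hypothetical stand-in for hivemind's task record (future + request payload)
Task = namedtuple("Task", ["future", "args"])


def iterate_single_task_batches(tasks: Queue, timeout: float = 1.0):
    """Yield one-element batches, skipping tasks whose futures were already cancelled."""
    while True:
        try:
            task = tasks.get(timeout=timeout)
        except Empty:
            # Nothing queued yet; keep polling (the patch logs a warning at this point)
            continue
        try:
            # Returns False if the caller cancelled the request before it started running
            if task.future.set_running_or_notify_cancel():
                yield [task]
        except InvalidStateError:
            # Mirrors the patch's handling of hivemind's InvalidStateError: drop the task
            pass


# Usage: enqueue one task and take a single batch from the generator
task_queue = Queue()
task_queue.put(Task(future=Future(), args=(1, 2)))
batch = next(iterate_single_task_batches(task_queue))
assert len(batch) == 1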