
Return multibatch mode

Artem Chumachenko · 3 years ago · commit f10b05f912

3 changed files with 33 additions and 29 deletions:

  1. src/client/remote_generation.py (+1, -0)
  2. src/client/remote_model.py (+1, -28)
  3. src/server/backend.py (+31, -1)

src/client/remote_generation.py (+1, -0)

@@ -17,6 +17,7 @@ class RemoteGenerationMixin:
     This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used in place of it; however, it has some differences.
     """
 
+    @torch.no_grad()
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
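
The decorator added here disables gradient tracking for the entire `generate` call, so autoregressive decoding no longer records an autograd graph at each step. A minimal standalone sketch of the effect in plain PyTorch (illustrative only, not project code):

```python
import torch

x = torch.randn(2, 3, requires_grad=True)

with torch.no_grad():
    y = x * 2  # executed without recording autograd history

# y is detached from the graph: no grad_fn and no requires_grad
assert y.grad_fn is None
assert not y.requires_grad
```

Since generation never needs backpropagation, skipping graph construction trims per-token memory and compute without changing outputs.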

src/client/remote_model.py (+1, -28)

@@ -156,7 +156,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
         return transformer_outputs
 
 
-class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
+class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     config_class = DistributedBloomConfig
@@ -190,33 +190,6 @@ class DistributedBloomForCausalLM(BloomForCausalLM, RemoteGenerationMixin):
             self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
             self.lm_head.bias[...] = new_lm_head.bias
 
-    def generate(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        do_sample: Optional[bool] = None,
-        temperature: float = 1.0,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        eos_token_id: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        decoding_algorithm: Optional[DecodingAlgorithm] = None,
-        provided_constraints: List[ABCBloomConstraint] = [],
-        **model_kwargs,
-    ) -> torch.Tensor:
-        return RemoteGenerationMixin.generate(
-            self,
-            inputs=inputs,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            eos_token_id=eos_token_id,
-            max_new_tokens=max_new_tokens,
-            decoding_algorithm=decoding_algorithm,
-            provided_constraints=provided_constraints,
-            **model_kwargs,
-        )
-
 
 class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
     config_class = DistributedBloomConfig
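
The two hunks above are one change: with `RemoteGenerationMixin` placed first among the bases, Python's method resolution order picks the mixin's `generate` over `BloomForCausalLM.generate`, which makes the explicit delegating override redundant, hence its removal. A minimal MRO sketch with placeholder class names (illustrative only, not project code):

```python
class Base:
    def generate(self):
        return "base"

class Mixin:
    def generate(self):
        return "mixin"

class OldOrder(Base, Mixin):  # Base comes first: Base.generate shadows the mixin's
    pass

class NewOrder(Mixin, Base):  # mixin comes first: Mixin.generate wins in the MRO
    pass

assert OldOrder().generate() == "base"
assert NewOrder().generate() == "mixin"
```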

src/server/backend.py (+31, -1)

@@ -1,16 +1,46 @@
 """Code for serving bloom blocks via hivemind-server"""
 from typing import Sequence, Tuple
+from queue import Empty
 
 import torch
+from hivemind import use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
+from hivemind.utils import InvalidStateError, get_logger
 
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
 
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
 MAX_LENGTH = 2048
 
 
+class InferenceTaskPool(TaskPool):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool must be equal to 1"
+
+    def iterate_minibatches(self, *args, **kwargs):
+        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
+
+        while True:
+            try:
+                logger.debug(f"{self.name} getting next task")
+                task = self.tasks.get(timeout=self.timeout)
+            except Empty:
+                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
+                continue
+
+            try:
+                if task.future.set_running_or_notify_cancel():
+                    yield [task]
+            except InvalidStateError as e:
+                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
+
+
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
@@ -23,7 +53,7 @@ class TransformerBackend(ModuleBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer buffers must not require grad, but {name} does"
 
-        self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
+        self.inference_pool = InferenceTaskPool(self.inference_step, max_batch_size=4096, name=f"{self.name}_inference")
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
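
Taken together, these changes restore multibatch inference: `iterate_minibatches` in the new `InferenceTaskPool` yields each queued task as its own single-task batch and simply keeps polling on queue timeouts, while raising `max_batch_size` from 1 to 4096 lifts the previous one-request cap, presumably so a single step can carry a larger batch of samples, in line with the commit title. A minimal standalone sketch of the same queue-polling pattern (plain Python, not hivemind code):

```python
from queue import Queue, Empty

def iterate_single_task_batches(tasks: Queue, timeout: float = 1.0):
    """Yield each queued task as its own one-element batch, waiting out timeouts."""
    while True:
        try:
            task = tasks.get(timeout=timeout)
        except Empty:
            continue  # nothing arrived in time; keep polling
        yield [task]  # a "minibatch" of exactly one task

q = Queue()
q.put("request-1")
assert next(iterate_single_task_batches(q)) == ["request-1"]
```

A plausible reason for one-task batches (an assumption, not stated in the diff) is that concurrent inference requests hold distinct attention caches and sequence lengths, so fusing them into a single tensor batch would be incorrect for stateful decoding.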