@@ -1,6 +1,6 @@
"""Code for serving bloom blocks via hivemind-server"""
from queue import Empty
from typing import Sequence, Tuple, Dict, Any, Optional

import torch
from hivemind import use_hivemind_log_handler, BatchTensorDescriptor
@@ -14,26 +14,28 @@ from src.server.cache import MemoryCache
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

+MAX_LENGTH = 2048
+

class InferenceTaskPool(TaskPool):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater than 1"

    def iterate_minibatches(self, *args, **kwargs):
        """Form minibatches by yielding tasks one at a time: each request becomes its own single-task batch"""

        while True:
            try:
                logger.debug(f"{self.name} getting next task")
                task = self.tasks.get(timeout=self.timeout)
            except Empty:
                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
                continue

            try:
                if task.future.set_running_or_notify_cancel():
                    yield [task]
            except InvalidStateError as e:
                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
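
For reference, the batching policy adopted above for inference — pull one queued task at a time and emit it as a singleton batch unless its future was already cancelled — can be reproduced with the standard library alone. The following sketch is illustrative only: DemoTask and demo_iterate_minibatches are invented names, and hivemind's actual Task and TaskPool interfaces differ in detail.

    from concurrent.futures import Future
    from dataclasses import dataclass, field
    from queue import Empty, Queue


    @dataclass
    class DemoTask:
        args: tuple
        future: Future = field(default_factory=Future)


    def demo_iterate_minibatches(tasks: Queue, timeout: float = 0.1):
        """Yield single-task batches, skipping tasks whose futures were cancelled."""
        while True:
            try:
                task = tasks.get(timeout=timeout)
            except Empty:
                return  # the real pool would keep waiting; we stop so the demo terminates
            # set_running_or_notify_cancel() returns False if the future was cancelled in the meantime
            if task.future.set_running_or_notify_cancel():
                yield [task]


    tasks_queue: Queue = Queue()
    live, dead = DemoTask(args=(1,)), DemoTask(args=(2,))
    dead.future.cancel()  # simulate a client that gave up before its task was batched
    tasks_queue.put(live)
    tasks_queue.put(dead)

    for batch in demo_iterate_minibatches(tasks_queue):
        print(f"batch of {len(batch)} task(s): args={batch[0].args}")  # only the live task appears

Together with the min_batch_size == 1 assertion in the constructor, this policy means every inference request is scheduled in isolation rather than padded into a larger batch.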
|