
Priority tasks (#47)

* priority in handlers and backend pools
* simple points system on server side
* prioritize task in handler before submitting task
* fix tests
* s/expert/block/g

Co-authored-by: justheuristic <justheuristic@gmail.com>
Pavel Samygin 2 years ago
parent
commit
50535a8435
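
Taken together, the changes below thread a client-supplied points budget through request metadata and into server-side task priorities: the client packs points into the MSGPack metadata of each request, and the connection handler divides them across tokens and blocks before submitting work to prioritized pools. A minimal sketch of that arithmetic, with made-up max_length, points, and block-count values:

    from hivemind import MSGPackSerializer

    # client side: attach a points budget to the request metadata (as in inference_session.py below)
    metadata = MSGPackSerializer.dumps(dict(max_length=128, points=60.0))

    # server side: recover the points and split them across tokens and blocks (as in handler.py below)
    parsed = MSGPackSerializer.loads(metadata)
    points = parsed.get("points", 0)
    point_per_piece = points / parsed["max_length"] if parsed["max_length"] > 0 else 0.0
    num_requested_backends = 4  # hypothetical number of blocks served by this request
    print(point_per_piece, point_per_piece / num_requested_backends)  # 0.46875 0.1171875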

+ 7 - 3
cli/run_server.py

@@ -31,15 +31,19 @@ def main():
     parser.add_argument('--num_handlers', type=int, default=8, required=False,
                         help='server will use this many processes to handle incoming requests')
     parser.add_argument('--min_batch_size', type=int, default=1,
-                        help='Minimum required batch size for all expert operations')
+                        help='Minimum required batch size for all operations (in total tokens)')
     parser.add_argument('--max_batch_size', type=int, default=16384,
                         help='The total number of tokens in the same batch will not exceed this value')
+    parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
+                        help='Pre-form this many subsequent batches while GPU is processing the current one')
+    parser.add_argument('--sender_threads', type=int, default=1, required=False,
+                        help='Use this many threads to pass results/exceptions from Runtime to Pools')
     parser.add_argument('--inference_max_length', type=int, default=16384,
                         help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
     parser.add_argument('--cache_dir', type=str, default=None, 
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
     parser.add_argument('--device', type=str, default=None, required=False,
-                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
+                        help='all blocks will use this device in torch notation; default: cuda if available else cpu')
     parser.add_argument("--torch_dtype", type=str, default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
@@ -58,7 +62,7 @@ def main():
                              'on the first run and uses these estimates for future runs. '
                              'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
     parser.add_argument('--update_period', type=float, required=False, default=30,
-                        help='Server will report experts to DHT once in this many seconds')
+                        help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],

+ 1 - 0
src/client/__init__.py

@@ -2,3 +2,4 @@ from src.client.inference_session import RemoteSequentialInferenceSession, Remot
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
+from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 2 - 1
src/client/inference_session.py

@@ -43,6 +43,7 @@ class RemoteTransformerBlockInferenceSession:
         outputs_aiter: AsyncIterator,
         *,
         max_length: int,
+        points: int = 0,
     ):
         self.uid, self.rpc_info = uid, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
@@ -50,7 +51,7 @@ class RemoteTransformerBlockInferenceSession:
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
-        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length))
+        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
         self.stepped = False
         self.closed = False
 

+ 156 - 0
src/client/remote_forward_backward.py

@@ -0,0 +1,156 @@
+"""
+Utility functions that call RPC forward or backward on a single remote server
+"""
+import asyncio
+from typing import Iterable, List, Sequence, Tuple
+
+import torch
+from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
+from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
+from hivemind.p2p import StubBase
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.streaming import split_for_streaming
+
+from src.data_structures import ModuleUID, RPCInfo
+
+
+async def run_remote_forward(
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
+) -> Tuple[torch.Tensor, ...]:
+    """
+    Serializes input tensors and calls "rpc_forward" on a remote server.
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+    """
+
+    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+    # detach to avoid pickling the computation graph
+    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
+    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
+
+    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
+    forward_inputs = (inputs, kwargs)
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: remove this assert once we support an arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    if not nested_compare(forward_inputs, forward_schema_with_prompts):
+        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
+    forward_inputs = nested_flatten(forward_inputs)
+    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
+
+    # Asynchronous serialization
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
+        )
+    )
+
+    # call RPC on remote server
+    size = sum(t.element_size() * t.nelement() for t in inputs)
+    if size > MAX_UNARY_PAYLOAD_SIZE:
+        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
+    else:
+        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)
+
+    return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
+
+
+async def _forward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
+
+    outputs = await stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
+            iter_as_aiter(split),
+        ),
+    )
+
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _forward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def _backward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
+
+    grad_inputs = await stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
+            iter_as_aiter(split),
+        ),
+    )
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def run_remote_backward(
+    uid: ModuleUID,
+    stub: StubBase,
+    rpc_info: RPCInfo,
+    inputs: torch.Tensor,
+    grad_outputs: List[torch.Tensor],
+    *extra_tensors: torch.Tensor,
+    **kwargs,
+) -> Sequence[torch.Tensor]:
+    """
+    Serializes grad outputs and calls "rpc_backward" on a remote server.
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+    """
+
+    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
+    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
+    # TODO generalize this
+    prompts_schema = next(iter(args_schema))
+    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
+
+    # Asynchronous serialization
+    loop = asyncio.get_running_loop()
+    serialized_tensors = await asyncio.gather(
+        *(
+            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        )
+    )
+
+    size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
+    if size > MAX_UNARY_PAYLOAD_SIZE:
+        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs)
+    else:
+        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)
+
+    return deserialized_grad_inputs
+
+
+async def _backward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
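
The helpers above choose between unary and streaming RPC from the raw payload size and chunk streamed tensors with split_for_streaming. A standalone sketch of that decision, using a made-up tensor shape (the size constants come from hivemind, not from this PR):

    import torch
    from hivemind import serialize_torch_tensor
    from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
    from hivemind.utils.streaming import split_for_streaming

    tensor = torch.randn(1, 2048, 4096)  # hypothetical hidden states
    payload_bytes = tensor.element_size() * tensor.nelement()
    print(payload_bytes, "bytes ->", "rpc_forward_stream" if payload_bytes > MAX_UNARY_PAYLOAD_SIZE else "rpc_forward")

    # when streaming, each serialized tensor is chunked to fit the daemon's message size limit
    parts = list(split_for_streaming(serialize_torch_tensor(tensor), DEFAULT_MAX_MSG_SIZE))
    print("split into", len(parts), "messages")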

+ 3 - 1
src/client/sequence_manager.py

@@ -9,6 +9,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+from src.client.spending_policy import NoSpendingPolicy
 from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
 from src.dht_utils import get_remote_module_infos
 from src.server.handler import TransformerConnectionHandler
@@ -24,6 +25,7 @@ class RemoteSequenceManager:
     """
 
     def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
+        assert len(block_uids) > 0, "Sequences must contain at least one block"
         self.dht, self.p2p = dht, p2p
         self.block_uids: List[ModuleUID] = list(block_uids)
         self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
@@ -39,7 +41,7 @@ class RemoteSequenceManager:
             assert info is not None, f"Found no remote peers for block {uid}"
         assert self.spans_by_priority and self.spans_containing_block
 
-    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]:
+    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
         """
         Form a sequence of remote servers that collectively serve all consecutive layers
 

+ 10 - 89
src/client/sequential_autograd.py

@@ -1,102 +1,22 @@
+"""
+A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
+"""
 import asyncio
 import logging
 from typing import List, Optional, Sequence, Tuple
 
 import torch
-from hivemind import serialize_torch_tensor
-from hivemind.moe.client.expert import expert_backward, expert_forward
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import StubBase
-from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
 
+from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
 from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from src.server.handler import TransformerConnectionHandler
 from src.utils.misc import DUMMY, is_dummy
 
 MAX_TOKENS_IN_BATCH = 1024
 
 
-async def run_expert_forward(
-    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
-) -> Tuple[torch.Tensor, ...]:
-    """
-    Serializes input tensors and calls "expert_forward".
-    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
-    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
-    """
-
-    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
-    # detach to avoid pickling the computation graph
-    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
-    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
-
-    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
-    forward_inputs = (inputs, kwargs)
-
-    # Modify forward_schema to support prompts
-    args_schema, kwargs_schema = rpc_info["forward_schema"]
-    # TODO: rm this assert when support arbitrary number of input tensors
-    assert len(args_schema) == 1 and len(inputs) == 2
-    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
-
-    if not nested_compare(forward_inputs, forward_schema_with_prompts):
-        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
-
-    forward_inputs = nested_flatten(forward_inputs)
-    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
-
-    # Asynchronous serialization
-    loop = asyncio.get_running_loop()
-    serialized_tensors = await asyncio.gather(
-        *(
-            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
-        )
-    )
-
-    deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
-    flat_outputs = tuple(deserialized_outputs)
-    return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
-
-
-async def run_expert_backward(
-    uid: ModuleUID,
-    stub: StubBase,
-    rpc_info: RPCInfo,
-    inputs: torch.Tensor,
-    grad_outputs: List[torch.Tensor],
-    *extra_tensors: torch.Tensor,
-) -> Sequence[torch.Tensor]:
-    """
-    Serializes grad outputs and calls "expert_backward".
-    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
-    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
-    """
-
-    grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
-
-    # Modify forward_schema to support prompts
-    args_schema, kwargs_schema = rpc_info["forward_schema"]
-    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
-    # TODO generalize this
-    prompts_schema = next(iter(args_schema))
-    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
-
-    # Asynchronous serialization
-    loop = asyncio.get_running_loop()
-    serialized_tensors = await asyncio.gather(
-        *(
-            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        )
-    )
-
-    deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
-    return deserialized_grad_inputs
-
-
 async def sequential_forward(
     inputs: torch.Tensor,
     prompts: torch.Tensor,
@@ -121,16 +41,17 @@ async def sequential_forward(
     sequences = sequence_manager.make_sequence(start_index, end_index)
     intermediate_inputs = []
     done_sequences = []
+    outputs = inputs
 
     while len(sequences) > 0:
         while True:
+            span = sequences.pop(0)
+            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             try:
-                span = sequences.pop(0)
-                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
-                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
+                (outputs,) = await run_remote_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
 
                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -171,7 +92,7 @@ async def sequential_backward(
             span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             try:
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                grad_outputs, *span_grad_prompts = await run_expert_backward(
+                grad_outputs, *span_grad_prompts = await run_remote_backward(
                     span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
                 )
                 grad_outputs = [grad_outputs]

+ 14 - 0
src/client/spending_policy.py

@@ -0,0 +1,14 @@
+from abc import ABC, abstractmethod
+
+from hivemind.proto.runtime_pb2 import ExpertRequest
+
+
+class SpendingPolicyBase(ABC):
+    @abstractmethod
+    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+        pass
+
+
+class NoSpendingPolicy(SpendingPolicyBase):
+    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+        return 0.0
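
SpendingPolicyBase only fixes the interface, and NoSpendingPolicy always bids zero. A hypothetical policy that spends points in proportion to the number of input tokens might look like the sketch below; the per-token rate and the reliance on the first tensor's shape are assumptions for illustration, not part of this PR:

    from hivemind.proto.runtime_pb2 import ExpertRequest

    from src.client.spending_policy import SpendingPolicyBase


    class TokenCountSpendingPolicy(SpendingPolicyBase):
        """Hypothetical policy: bid a fixed number of points per input token."""

        def __init__(self, points_per_token: float = 1.0):
            self.points_per_token = points_per_token

        def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
            if not request.tensors:
                return 0.0
            size = request.tensors[0].size  # assumed to be [batch_size, seq_length, ...]
            num_tokens = size[0] * size[1] if len(size) >= 2 else size[0]
            return float(num_tokens) * self.points_per_token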

+ 13 - 31
src/server/backend.py

@@ -1,45 +1,20 @@
 """Code for serving bloom blocks via hivemind-server"""
-from queue import Empty
 from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
 from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
-from hivemind.moe.server.task_pool import TaskPool
-from hivemind.utils import InvalidStateError, get_logger
+from hivemind.utils import get_logger
 
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
+from src.server.task_pool import PrioritizedTaskPool
 from src.utils.misc import is_dummy
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class InferenceTaskPool(TaskPool):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater 1"
-
-    def iterate_minibatches(self, *args, **kwargs):
-        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
-
-        while True:
-            try:
-                logger.debug(f"{self.name} getting next task")
-                task = self.tasks.get(timeout=self.timeout)
-            except Empty:
-                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
-                continue
-
-            try:
-                if task.future.set_running_or_notify_cancel():
-                    yield [task]
-            except InvalidStateError as e:
-                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
-
-
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
@@ -52,8 +27,15 @@ class TransformerBackend(ModuleBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-        self.inference_pool = InferenceTaskPool(
-            self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
+        max_batch_size = self.forward_pool.max_batch_size
+        self.inference_pool = PrioritizedTaskPool(
+            self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
+        )
+        self.forward_pool = PrioritizedTaskPool(
+            self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
+        )
+        self.backward_pool = PrioritizedTaskPool(
+            self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
         )
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
         self.inference_schema = (
@@ -94,9 +76,9 @@ class TransformerBackend(ModuleBackend):
                 cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
                 return (hidden_states,)
 
-    def get_pools(self) -> Sequence[TaskPool]:
+    def get_pools(self) -> Sequence[PrioritizedTaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
 
     def get_info(self) -> Dict[str, Any]:
-        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
         return dict(super().get_info(), inference_schema=self.inference_schema)

+ 118 - 18
src/server/handler.py

@@ -1,5 +1,5 @@
 import contextlib
-from typing import AsyncIterator, Dict, List, Optional, Sequence, Union
+from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
 
 import torch
 from hivemind import (
@@ -7,6 +7,7 @@ from hivemind import (
     MSGPackSerializer,
     P2PContext,
     TensorDescriptor,
+    deserialize_tensor_stream,
     deserialize_torch_tensor,
     nested_flatten,
     serialize_torch_tensor,
@@ -14,12 +15,13 @@ from hivemind import (
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils import as_aiter
-from hivemind.utils.asyncio import anext
+from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import TransformerBackend
+from src.server.task_pool import PrioritizedTaskPool
+from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from src.utils.misc import DUMMY, is_dummy
 
 
@@ -28,11 +30,41 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     module_backends: Dict[ModuleUID, TransformerBackend]
 
-    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
+    def __init__(
+        self,
+        dht: DHT,
+        module_backends: Dict[str, TransformerBackend],
+        inference_max_length: int,
+        task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
+    ):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
         self.inference_max_length = inference_max_length
+        self._prioritizer = task_prioritizer
+
+    async def _gather_inputs(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> Tuple[str, List[torch.Tensor], Dict]:
+        block_uid, metadata = None, None
+
+        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
+            nonlocal block_uid, metadata
+
+            if block_uid is None:
+                block_uid = req.uid
+            elif block_uid != req.uid:
+                raise ValueError("Block uids differ in one request")
+
+            if metadata is None:
+                metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}
+
+            return req.tensors
+
+        tensors_stream = amap_in_executor(_unpack, requests)
+        inputs = await deserialize_tensor_stream(tensors_stream)
+        assert isinstance(block_uid, str) and isinstance(metadata, dict)
+        return block_uid, inputs, metadata
 
     async def rpc_inference(
         self,
@@ -47,13 +79,18 @@ class TransformerConnectionHandler(ConnectionHandler):
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
+            points = metadata.get("points", 0)
 
             if not requested_uids:
                 raise ValueError("User must specify at least one block for inference, but got none")
             assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_inference should have number of points as a number or None, got {points}"
             if not 0 <= max_length <= self.inference_max_length:
                 raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
 
+            point_per_piece = points / max_length if max_length > 0 else 0.0
             batch_size = request.tensors[0].size[0] if request.tensors else 1
 
             cache_metadata = torch.tensor(
@@ -98,8 +135,19 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert (
                             hidden_states.ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+                        assert isinstance(
+                            backend.inference_pool, PrioritizedTaskPool
+                        ), "petals support only prioritized pools"
+                        priority = self._prioritizer.prioritize(
+                            cache_metadata,
+                            hidden_states,
+                            hypo_ids,
+                            points=point_per_piece / len(requested_backends),
+                            backend=backend,
+                            type="inference",
+                        )
                         (hidden_states,) = await backend.inference_pool.submit_task(
-                            cache_metadata, hidden_states, hypo_ids
+                            cache_metadata, hidden_states, hypo_ids, priority=priority
                         )
 
                     # serialize and send last layer outputs
@@ -123,8 +171,15 @@ class TransformerConnectionHandler(ConnectionHandler):
         flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-
-        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+        points = metadata.get("points", 0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_forward should have number of points as number or None, got {points}"
+
+        hidden_states = await _rpc_forward(
+            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+        )
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
         # Serialize output and respond to client
@@ -139,11 +194,17 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         # Parse requests and prepare backends
-        uid_str, flat_inputs = await self._gather_inputs(requests, context)
+        uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+        points = metadata.get("points", 0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
-        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        hidden_states = await _rpc_forward(
+            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+        )
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
 
         # Serialize the overall output
@@ -164,8 +225,15 @@ class TransformerConnectionHandler(ConnectionHandler):
         flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-
-        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+        points = metadata.get("points", 0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_backward should have number of points as number or None, got {points}"
+
+        grads = await _rpc_backward(
+            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+        )
 
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
@@ -187,11 +255,17 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
 
-        uids_header, flat_tensors = await self._gather_inputs(requests, context)
+        uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+        points = metadata.get("points", 0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
-        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+        grads = await _rpc_backward(
+            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+        )
 
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
@@ -244,7 +318,12 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield handles
 
 
-async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
+async def _rpc_forward(
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    prioritizer: TaskPrioritizerBase,
+    points: int = 0,
+) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
 
@@ -267,7 +346,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     for backend, prompt in zip(requested_backends, prompts):
         if not is_dummy(prompt):
             hidden_states[:, : prompt.shape[1]] += prompt
-        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+
+        assert isinstance(backend.forward_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
+        priority = prioritizer.prioritize(
+            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
+        )
+        (hidden_states,) = await backend.forward_pool.submit_task(
+            hidden_states,
+            priority=priority,
+        )
         assert isinstance(hidden_states, torch.Tensor)
         assert (
             hidden_states.ndim == 3
@@ -278,7 +365,10 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
 
 
 async def _rpc_backward(
-    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    prioritizer: TaskPrioritizerBase,
+    points: int = 0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
     inputs, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
@@ -298,7 +388,12 @@ async def _rpc_backward(
         if not is_dummy(prompt):
             inputs[:, : prompt.shape[1]] += prompt
         inter_inputs.append(inputs)
-        (inputs,) = await backend.forward_pool.submit_task(inputs)
+        assert isinstance(backend.forward_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
+        priority = prioritizer.prioritize(
+            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+        )
+        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
+
         assert isinstance(inputs, torch.Tensor)
 
     if not is_dummy(prompts[-1]):
@@ -309,7 +404,12 @@ async def _rpc_backward(
     grad_prompts_reversed = []
     # Run a chain of requested backends
     for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(backend.backward_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
+        priority = prioritizer.prioritize(
+            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
+        )
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
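
To make the split above concrete (with arbitrary numbers): an rpc_inference call with points=800, max_length=100 and 8 requested blocks hands 800 / 100 / 8 = 1 point to the prioritizer for each per-step, per-block task, while rpc_forward and rpc_backward divide the same 800 points only across blocks, giving 100 points per block-level task.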

+ 198 - 0
src/server/runtime.py

@@ -0,0 +1,198 @@
+import multiprocessing as mp
+import multiprocessing.pool
+import threading
+from collections import defaultdict
+from itertools import chain
+from queue import SimpleQueue
+from selectors import EVENT_READ, DefaultSelector
+from statistics import mean
+from time import time
+from typing import Dict, NamedTuple, Optional
+
+import torch
+from hivemind.moe.server.module_backend import ModuleBackend
+from hivemind.utils import get_logger
+from prefetch_generator import BackgroundGenerator
+
+logger = get_logger(__name__)
+
+
+class Runtime(threading.Thread):
+    """
+    A group of processes that processes incoming requests for multiple module backends on a shared device.
+    Runtime is usually created and managed by Server, humans need not apply.
+
+    For debugging, you can start runtime manually with .start() or .run()
+
+    >>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
+    >>> runtime = Runtime(module_backends)
+    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
+    >>> runtime.ready.wait()  # await for runtime to load all blocks on device and create request pools
+    >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
+    >>> print("Returned:", future.result())
+    >>> runtime.shutdown()
+
+    :param module_backends: a dict [block uid -> ModuleBackend]
+    :param prefetch_batches: form up to this many batches in advance
+    :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
+    :param device: if specified, moves all blocks and data to this device via .to(device=device).
+      If you want to manually specify devices for each block (in their forward pass), leave device=None (default)
+
+    :param stats_report_interval: interval to collect and log statistics about runtime performance
+    """
+
+    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
+
+    def __init__(
+        self,
+        module_backends: Dict[str, ModuleBackend],
+        prefetch_batches: int = 1,
+        sender_threads: int = 1,
+        device: torch.device = None,
+        stats_report_interval: Optional[int] = None,
+    ):
+        super().__init__()
+        self.module_backends = module_backends
+        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
+        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
+        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
+        self.shutdown_trigger = mp.Event()
+        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
+
+        self.stats_report_interval = stats_report_interval
+        if self.stats_report_interval is not None:
+            self.stats_reporter = StatsReporter(self.stats_report_interval)
+
+    def run(self):
+        for pool in self.pools:
+            if not pool.is_alive():
+                pool.start()
+        if self.device is not None:
+            for backend in self.module_backends.values():
+                backend.module.to(self.device)
+
+        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
+            try:
+                self.ready.set()
+                if self.stats_report_interval is not None:
+                    self.stats_reporter.start()
+                logger.info("Started")
+
+                batch_iterator = self.iterate_minibatches_from_pools()
+                if self.prefetch_batches > 0:
+                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)
+
+                for pool, batch_index, batch in batch_iterator:
+                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
+
+                    start = time()
+                    try:
+                        outputs = pool.process_func(*batch)
+                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
+
+                        batch_processing_time = time() - start
+
+                        batch_size = outputs[0].size(0)
+                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
+
+                        if self.stats_report_interval is not None:
+                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
+
+                    except KeyboardInterrupt:
+                        raise
+                    except BaseException as exception:
+                        logger.exception(f"Caught {exception}, attempting to recover")
+                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
+
+            finally:
+                if not self.shutdown_trigger.is_set():
+                    self.shutdown()
+
+    def shutdown(self):
+        """Gracefully terminate a running runtime."""
+        logger.info("Shutting down")
+        self.ready.clear()
+
+        if self.stats_report_interval is not None:
+            self.stats_reporter.stop.set()
+            self.stats_reporter.join()
+
+        logger.debug("Terminating pools")
+        for pool in self.pools:
+            if pool.is_alive():
+                pool.shutdown()
+        logger.debug("Pools terminated")
+
+        # trigger background thread to shutdown
+        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
+        self.shutdown_trigger.set()
+
+    def iterate_minibatches_from_pools(self, timeout=None):
+        """
+        Chooses pool according to priority, then copies exposed batch and frees the buffer
+        """
+        with DefaultSelector() as selector:
+            for pool in self.pools:
+                selector.register(pool.batch_receiver, EVENT_READ, pool)
+            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
+
+            while True:
+                # wait until at least one batch_receiver becomes available
+                logger.debug("Waiting for inputs from task pools")
+                ready_fds = selector.select()
+                ready_objects = {key.data for (key, events) in ready_fds}
+                if self.SHUTDOWN_TRIGGER in ready_objects:
+                    break  # someone asked us to shutdown, break from the loop
+
+                logger.debug("Choosing the pool with first priority")
+
+                pool = min(ready_objects, key=lambda pool: pool.priority)
+
+                logger.debug(f"Loading batch from {pool.name}")
+                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
+                logger.debug(f"Loaded batch from {pool.name}")
+                yield pool, batch_index, batch_tensors
+
+
+BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))
+
+
+class StatsReporter(threading.Thread):
+    def __init__(self, report_interval: int):
+        super().__init__()
+        self.report_interval = report_interval
+        self.stop = threading.Event()
+        self.stats_queue = SimpleQueue()
+
+    def run(self):
+        while not self.stop.wait(self.report_interval):
+            pool_batch_stats = defaultdict(list)
+            while not self.stats_queue.empty():
+                pool_uid, batch_stats = self.stats_queue.get()
+                pool_batch_stats[pool_uid].append(batch_stats)
+
+            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
+            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
+            for pool_uid, pool_stats in pool_batch_stats.items():
+                total_batches = len(pool_stats)
+                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
+                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
+                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
+                batches_to_time = total_batches / total_time
+                batch_performance = f"{batches_to_time:.2f} " + ("batches/s" if batches_to_time > 1 else "s/batch")
+
+                examples_to_time = total_examples / total_time
+                example_performance = f"{examples_to_time:.2f} " + (
+                    "examples/s" if examples_to_time > 1 else "s/example"
+                )
+
+                logger.info(
+                    f"{pool_uid}: "
+                    f"{total_batches} batches ({batch_performance}), "
+                    f"{total_examples} examples ({example_performance}), "
+                    f"avg batch size {avg_batch_size:.2f}"
+                )
+
+    def report_stats(self, pool_uid, batch_size, processing_time):
+        batch_stats = BatchStats(batch_size, processing_time)
+        self.stats_queue.put_nowait((pool_uid, batch_stats))
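
Runtime.iterate_minibatches_from_pools picks the next pool by ordinary tuple comparison of each pool's (priority, timestamp) pair: a lower priority value always wins, and ties fall back to the earlier submission time. A standalone illustration with invented pool names and numbers:

    import time

    t0 = time.monotonic()
    # (priority value, submission timestamp), as exposed by PrioritizedTaskPool.priority
    pool_priorities = {
        "block3_forward": (1.0, t0 + 0.02),
        "block3_inference": (0.0, t0 + 0.05),  # lower priority value -> served first
        "block3_backward": (1.0, t0 + 0.01),   # equal priorities -> earlier timestamp wins
    }
    print(sorted(pool_priorities, key=pool_priorities.get))
    # ['block3_inference', 'block3_backward', 'block3_forward']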

+ 6 - 2
src/server/server.py

@@ -71,9 +71,9 @@ class Server(threading.Thread):
         runs Runtime (self.runtime) to process incoming requests.
         """
         logger.info(f"Serving {len(self.module_backends)} blocks:")
-        for expert_name, backend in self.module_backends.items():
+        for block_name, backend in self.module_backends.items():
             num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
-            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
+            logger.info(f"{block_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
 
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)
@@ -118,6 +118,8 @@ class Server(threading.Thread):
         custom_module_path=None,
         update_period: float = 30,
         expiration: Optional[float] = None,
+        prefetch_batches: int = 1,
+        sender_threads: int = 1,
         max_block_selection_delay: float = 1,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
@@ -236,6 +238,8 @@ class Server(threading.Thread):
             stats_report_interval=stats_report_interval,
             update_period=update_period,
             expiration=expiration,
+            prefetch_batches=prefetch_batches,
+            sender_threads=sender_threads,
             start=start,
         )
 

+ 175 - 0
src/server/task_pool.py

@@ -0,0 +1,175 @@
+import ctypes
+import multiprocessing as mp
+import threading
+import time
+from dataclasses import dataclass, field
+from queue import PriorityQueue
+from typing import Any, Generator, List, Optional, Sequence, Tuple
+
+import torch
+from hivemind import MPFuture, get_logger, use_hivemind_log_handler
+from hivemind.moe.server.task_pool import TaskPoolBase
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+@dataclass(order=True, frozen=True)
+class Task:
+    priority: float
+    time_submitted: float
+    future: MPFuture = field(compare=False)
+    args: Sequence[torch.Tensor] = field(compare=False)
+
+    @property
+    def uid(self) -> int:
+        return self.future._uid
+
+
+class PrioritizedTaskPool(TaskPoolBase):
+    """
+    Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
+    returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
+    A single PrioritizedTaskPool services a specific function (e.g. layer1.forward, layer2.forward or layer1.backward)
+
+    :note: unlike hivemind.moe TaskPool, this pool does *not* combine incoming requests into batches.
+      This would require grouping requests of different length.
+
+    :param process_func: function to be applied to every formed batch; called by Runtime
+        Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
+    :param max_batch_size: process at most this many inputs in a batch (a task may contain one or several inputs)
+         Measured in the total number of tokens (i.e. batch size * sequence length)
+
+    :param name: pool name, used for logging
+    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
+    :param start: if True, start automatically at the end of __init__
+    """
+
+    def __init__(
+        self,
+        process_func: callable,
+        max_batch_size: int,
+        name: str,
+        min_batch_size=1,
+        daemon=True,
+        start=False,
+    ):
+        super().__init__(process_func, daemon=daemon, name=name)
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+
+        self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
+        self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
+
+        self._prioritizer_thread = threading.Thread(
+            name=self.name + "_prioritizer",
+            target=self._prioritize_tasks,
+            args=[self.submitted_tasks, self._ordered_tasks],
+            daemon=True,
+        )
+        self._dispatched_tasks = {}
+        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
+        self._priority = mp.Value(ctypes.c_double, 1.0)
+        self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
+        self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
+        if start:
+            self.start()
+
+    @staticmethod
+    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
+        """Read tasks from incoming queue and put them into a local priority queue"""
+        while True:
+            task = submitted_tasks.get()
+            if task is None:
+                logger.debug("Shutting down prioritizer thread")
+                break
+
+            ordered_tasks.put(task, block=True)
+
+    def start(self):
+        assert not self.is_alive() and not self._prioritizer_thread.is_alive()
+        self._prioritizer_thread.start()
+        super().start()
+
+    def shutdown(self, timeout: Optional[float] = None):
+        self.submitted_tasks.put(None)
+        self.terminate()
+        self._prioritizer_thread.join(timeout)
+
+    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
+        """Add task to this pool's queue, return Future for its output"""
+        task = Task(priority, time.monotonic(), MPFuture(), args)
+        if self.get_task_size(task) > self.max_batch_size:
+            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
+            task.future.set_exception(exc)
+        else:
+            self.submitted_tasks.put(task)
+            self.batch_sender.send(None)  # use this pipe to count the number of unfinished batches
+            if (task.priority, task.time_submitted) < self.priority:
+                self.priority = (task.priority, task.time_submitted)
+        return task.future
+
+    def get_task_size(self, task: Task) -> int:
+        """compute task processing complexity; defaults to the total number of tokens"""
+        if task.args and task.args[0].ndim >= 2:
+            return task.args[0].shape[0] * task.args[0].shape[1]
+        return 1
+
+    def load_batch_to_runtime(
+        self, timeout: Optional[float] = None, device: Optional[torch.device] = None
+    ) -> Tuple[Any, List[torch.Tensor]]:
+        """receive next batch of arrays"""
+        task = self._ordered_tasks.get(block=True, timeout=timeout)
+        batch_inputs = [
+            tensor.detach().to(device, non_blocking=True).requires_grad_(tensor.requires_grad) for tensor in task.args
+        ]
+        self._dispatched_tasks[task.uid] = task
+        self.batch_receiver.recv()  # reduce the number of active batches
+        if not self._ordered_tasks.empty():
+            first_remaining_task: Task = self._ordered_tasks.queue[0]
+            self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted)
+        return task.uid, batch_inputs
+
+    def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
+        """send results for a processed batch, previously loaded through load_batch_to_runtime"""
+        batch_outputs = [
+            tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
+            for tensor in batch_outputs
+        ]
+
+        task = self._dispatched_tasks.pop(uid, None)
+        if task is None:
+            logger.error(
+                f"Internal error: task task with index {uid} is missing from the dictionary; " f"Could not set result"
+            )
+        else:
+            task.future.set_result(batch_outputs)
+
+    def send_exception_from_runtime(self, uid: int, exception: BaseException):
+        task = self._dispatched_tasks.pop(uid, None)
+        if task is None:
+            logger.error(
+                f"Internal error: task task with index {uid} is missing from the dictionary; "
+                f"Could not set exception {exception}"
+            )
+        else:
+            task.future.set_exception(exception)
+
+    def run(self, *args, **kwargs):
+        mp.Event().wait()
+
+    @property
+    def empty(self):
+        return not self.batch_receiver.poll()
+
+    @property
+    def priority(self) -> Tuple[float, float]:
+        """The priority of this pool equals the (priority, timestamp) of the most important task in it."""
+        return float(self._priority.value), float(self._oldest_undispatched_timestamp.value)
+
+    @priority.setter
+    def priority(self, item: Tuple[float, float]):
+        assert len(item) == 2
+        self._priority.value = float(item[0])
+        self._oldest_undispatched_timestamp.value = float(item[1])
+
+    def iterate_minibatches(self, *args, **kwargs) -> Generator[List[Task], None, None]:
+        raise NotImplementedError()
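
A minimal sketch of how this pool is meant to be driven, shown here in a single process for clarity; in the server, submit_task is called by connection handlers while load_batch_to_runtime / send_outputs_from_runtime are called by Runtime, and the doubling function is purely illustrative:

    import time

    import torch

    from src.server.task_pool import PrioritizedTaskPool


    def double(x: torch.Tensor):
        return (x * 2,)


    pool = PrioritizedTaskPool(double, max_batch_size=1024, name="demo", start=True)

    low = pool.submit_task(torch.ones(1, 4), priority=10.0)   # higher value = served later
    high = pool.submit_task(torch.zeros(1, 4), priority=0.0)  # lower value = served first
    time.sleep(0.1)  # let the background prioritizer thread order the queue

    for _ in range(2):  # play the Runtime's role: fetch tasks in priority order, return results
        uid, batch = pool.load_batch_to_runtime()
        pool.send_outputs_from_runtime(uid, list(double(*batch)))

    print(high.result()[0], low.result()[0])
    pool.shutdown()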

+ 20 - 0
src/server/task_prioritizer.py

@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+
+import torch
+from hivemind.moe.server.task_pool import Task
+
+
+class TaskPrioritizerBase(ABC):
+    """Abstract class for TaskPrioritizer whose reponsibility is to evaluate task priority"""
+
+    @abstractmethod
+    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
+        """Evaluates task value by the amout of points given, task input and additional kwargs. Lower priority is better"""
+        pass
+
+
+class DummyTaskPrioritizer(TaskPrioritizerBase):
+    """Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
+
+    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
+        return 0.0
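
DummyTaskPrioritizer ignores the points entirely. A hypothetical prioritizer that honors them could simply negate the per-task points, so that tasks backed by more points get a lower (better) priority value; this is only a sketch of one possible policy, not part of this PR:

    import torch

    from src.server.task_prioritizer import TaskPrioritizerBase


    class PointsBasedTaskPrioritizer(TaskPrioritizerBase):
        """Hypothetical prioritizer: tasks that carry more points are served earlier."""

        def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
            return -float(points)  # lower return value = higher priority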

+ 71 - 0
tests/test_priority_pool.py

@@ -0,0 +1,71 @@
+import multiprocessing as mp
+import time
+
+import pytest
+import torch
+
+from src.server.runtime import Runtime
+from src.server.task_pool import PrioritizedTaskPool
+
+
+@pytest.mark.forked
+def test_priority_pools():
+    outputs_queue = mp.SimpleQueue()
+    results_valid = mp.Event()
+
+    def dummy_pool_func(x):
+        time.sleep(0.1)
+        y = x**2
+        outputs_queue.put((x, y))
+        return (y,)
+
+    class DummyBackend:
+        def __init__(self, pools):
+            self.pools = pools
+
+        def get_pools(self):
+            return self.pools
+
+    pools = (
+        PrioritizedTaskPool(dummy_pool_func, name="A", max_batch_size=1),
+        PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
+    )
+
+    runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
+    runtime.start()
+
+    def process_tasks():
+        futures = []
+        futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
+        futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
+        time.sleep(0.01)
+        futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
+        futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
+        futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
+        futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
+        futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
+        futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
+        futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
+        for i, f in enumerate(futures):
+            assert f.result()[0].item() == i**2
+        results_valid.set()
+
+    proc = mp.Process(target=process_tasks)
+    proc.start()
+    proc.join()
+    assert results_valid.is_set()
+
+    ordered_outputs = []
+    while not outputs_queue.empty():
+        ordered_outputs.append(outputs_queue.get()[0].item())
+
+    assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7]
+    #                          0 - first batch is loaded immediately, before everything else
+    #                             5 - highest priority task overall
+    #                                1 - first of several tasks with equal lowest priority (1)
+    #                                   2 - second earliest task with priority 1, fetched from pool B
+    #                                      6 - third earliest task with priority 1, fetched from pool A again
+    #                                         8 - last priority-1 task, pool B
+    #                                            3 - task with priority 2 from pool A
+    #                                               4 - task with priority 10 from pool A
+    #                                                  7 - task with priority 11 from pool B