
priority in handlers and backend pools

Pavel Samygin 3 years ago
parent
commit
1117867815
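
This commit threads a client-supplied priority value, called "pollen" in the code, from the connection handlers down into new prioritized task pools on the server. The sketch below is a hypothetical client-side call and is not part of this commit: it only illustrates how a pollen value could be packed into the request metadata that the handlers read back with MSGPackSerializer.loads(request.metadata). The helper name build_request is an assumption, and it presumes MSGPackSerializer and serialize_torch_tensor are importable from hivemind, as the handler module suggests.

# Hypothetical client-side sketch (not part of this commit): pack a "pollen"
# priority value into ExpertRequest.metadata so the server-side handlers can
# recover it with MSGPackSerializer.loads(request.metadata).
import torch
from hivemind import MSGPackSerializer, serialize_torch_tensor
from hivemind.proto import runtime_pb2


def build_request(uid: str, hidden_states: torch.Tensor, pollen: float) -> runtime_pb2.ExpertRequest:
    return runtime_pb2.ExpertRequest(
        uid=uid,
        tensors=[serialize_torch_tensor(hidden_states)],
        metadata=MSGPackSerializer.dumps({"pollen": pollen}),
    )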

+ 89 - 3
src/server/backend.py

@@ -1,15 +1,21 @@
 """Code for serving bloom blocks via hivemind-server"""
-from queue import Empty
+import multiprocessing as mp
+import os
+import threading
+from concurrent.futures import Future
+from dataclasses import dataclass, field
+from queue import Empty, PriorityQueue
 from typing import Optional, Sequence, Tuple
 
 import torch
 from hivemind import use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
-from hivemind.moe.server.task_pool import TaskPool
-from hivemind.utils import InvalidStateError, get_logger
+from hivemind.moe.server.task_pool import Task, TaskPool
+from hivemind.utils import InvalidStateError, MPFuture, get_logger
 
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
+from src.server.task_broker import SimpleBroker, TaskBrokerBase
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -17,6 +23,86 @@ logger = get_logger(__file__)
 MAX_LENGTH = 2048
 
 
+@dataclass(order=True)
+class PrioritizedTask:
+    value: float
+    task: Task = field(compare=False)
+
+
+class PrioritizedTaskPool(TaskPool):
+    def __init__(self, *args, broker: TaskBrokerBase = SimpleBroker(), **kwargs):
+        super().__init__(*args, **kwargs)
+        self.broker = broker
+        self.pollen_queue = mp.Queue(maxsize=self.tasks.maxsize)
+        self.priority_queue = PriorityQueue(maxsize=self.tasks.maxsize)
+
+    def submit_task(self, *args: torch.Tensor, pollen: float = 0.0) -> Future:
+        f = super().submit_task(*args)
+        self.pollen_queue.put(pollen)
+        return f
+
+    def _prioritize_tasks(self):
+        """Infinite loop prioritizing incoming tasks"""
+        while True:
+            task = self.tasks.get(block=True)
+            pollen = self.pollen_queue.get(block=True)
+            self.priority_queue.put(PrioritizedTask(-self.broker(task, pollen), task), block=True)
+
+    def run(self, *args, **kwargs):
+        torch.set_num_threads(1)
+        logger.info(f"{self.name} starting, pid={os.getpid()}")
+        pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
+
+        output_thread = threading.Thread(
+            target=self._pool_output_loop, args=[pending_batches], name=f"{self.name}_output", daemon=True
+        )
+        priority_thread = threading.Thread(
+            target=self._prioritize_tasks, args=[], name=f"{self.name}_priority", daemon=True
+        )
+
+        try:
+            output_thread.start()
+            priority_thread.start()
+            self._pool_input_loop(pending_batches, *args, **kwargs)
+        except KeyboardInterrupt:
+            logger.debug("Caught KeyboardInterrupt, shutting down")
+        finally:
+            output_thread.join()
+            priority_thread.join()
+
+    # TODO: this is a copy-paste of the original method, except that we use a different queue
+    def iterate_minibatches(self, *args, **kwargs):
+        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
+        batch = []
+        total_size = 0
+
+        while True:
+            if total_size >= self.min_batch_size and self.priority_queue.empty():
+                yield batch
+                batch = []
+                total_size = 0
+            try:
+                logger.debug(f"{self.name} getting next task")
+                task = self.priority_queue.get(timeout=self.timeout).task  # unwrap PrioritizedTask
+            except Empty:
+                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
+                continue
+
+            task_size = self.get_task_size(task)
+
+            if total_size + task_size > self.max_batch_size:
+                yield batch
+                batch = []
+                total_size = 0
+
+            try:
+                if task.future.set_running_or_notify_cancel():
+                    batch.append(task)
+                    total_size += task_size
+            except InvalidStateError as e:
+                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
+
+
 class InferenceTaskPool(TaskPool):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
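
How PrioritizedTaskPool orders work: _prioritize_tasks pairs each incoming task with the pollen value recorded at submit_task time, asks the broker for a score, and stores the negated score in a PriorityQueue. Because Python's PriorityQueue is a min-heap, the negation makes the task with the highest broker score come out first. Below is a minimal standalone sketch of that trick; it is illustrative only and not part of the server code.

# Standalone illustration of the ordering used by PrioritizedTaskPool above:
# PriorityQueue pops the smallest value first, so storing -score makes the
# highest-scoring item come out first.
from dataclasses import dataclass, field
from queue import PriorityQueue


@dataclass(order=True)
class Prioritized:
    value: float
    payload: str = field(compare=False)  # excluded from comparisons


queue = PriorityQueue()
for name, score in [("low-priority", 0.5), ("high-priority", 2.0)]:
    queue.put(Prioritized(-score, name))  # negate: higher score => earlier dequeue

assert queue.get().payload == "high-priority"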

+ 59 - 17
src/server/handler.py

@@ -1,5 +1,5 @@
 import contextlib
-from typing import AsyncIterator, Dict, List, Optional, Sequence, Union
+from typing import AsyncIterator, Dict, List, Sequence, Union, Tuple, Iterable
 
 import torch
 from hivemind import (
@@ -8,18 +8,18 @@ from hivemind import (
     P2PContext,
     TensorDescriptor,
     deserialize_torch_tensor,
+    deserialize_tensor_stream,
     nested_flatten,
     serialize_torch_tensor,
 )
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils import as_aiter
-from hivemind.utils.asyncio import anext
+from hivemind.utils.asyncio import anext, amap_in_executor, as_aiter
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
-from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.server.backend import MAX_LENGTH, TransformerBackend, PrioritizedTaskPool
 from src.utils.misc import DUMMY, is_dummy
 
 
@@ -33,6 +33,29 @@ class TransformerConnectionHandler(ConnectionHandler):
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
 
+    async def _gather_inputs(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> Tuple[str, List[torch.Tensor], Dict]:
+        expert_uid, metadata = None, None
+
+        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
+            nonlocal expert_uid, metadata
+
+            if expert_uid is None:
+                expert_uid = req.uid
+            elif expert_uid != req.uid:
+                raise ValueError("Expert uids differ in one request")
+
+            if metadata is None:
+                metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}
+
+            return req.tensors
+
+        tensors_stream = amap_in_executor(_unpack, requests)
+        inputs = await deserialize_tensor_stream(tensors_stream)
+        return expert_uid, inputs, metadata
+
+
     async def rpc_inference(
         self,
         requests: AsyncIterator[runtime_pb2.ExpertRequest],
@@ -56,6 +79,8 @@ class TransformerConnectionHandler(ConnectionHandler):
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+                    metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+                    pollen = metadata.get("pollen", 0.0)
 
                     # Cast inputs to backend dtype
                     hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
@@ -66,8 +91,10 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert (
                             len(hidden_states) == 1 and hidden_states[0].ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-
-                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
+                        if isinstance(backend.inference_pool, PrioritizedTaskPool):
+                            hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states, pollen=pollen)
+                        else:
+                            hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                         assert isinstance(hidden_states, (list, tuple))
                         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
 
@@ -92,8 +119,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+        pollen = metadata.get("pollen", 0.0)
 
-        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, pollen=pollen)
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
         # Serialize output and respond to client
@@ -108,11 +137,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         # Parse requests and prepare backends
-        uid_str, flat_inputs = await self._gather_inputs(requests, context)
+        uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, pollen=metadata.get("pollen", 0.0))
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
         # Serialize the overall output
@@ -133,8 +162,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+        pollen = metadata.get("pollen", 0.0)
 
-        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, pollen=pollen)
 
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
@@ -156,11 +187,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
 
-        uids_header, flat_tensors = await self._gather_inputs(requests, context)
+        uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, pollen=metadata.get("pollen", 0.0))
 
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
@@ -211,7 +242,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield handles
 
 
-async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
+async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], pollen: float = 0.0) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
 
@@ -237,7 +268,10 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     for backend, prompt in zip(requested_backends, prompts):
         if not is_dummy(prompt):
             hidden_states[:, :pre_seq_len] += prompt
-        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+        if isinstance(backend.forward_pool, PrioritizedTaskPool):
+            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, pollen=pollen)
+        else:
+            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         assert isinstance(hidden_states, torch.Tensor)
         assert (
             hidden_states.ndim == 3
@@ -248,7 +282,7 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
 
 
 async def _rpc_backward(
-    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]
+    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], pollen: float = 0.0
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
     inputs, grad_outputs, *prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
@@ -271,7 +305,12 @@ async def _rpc_backward(
         if not is_dummy(prompt):
             inputs[:, :pre_seq_len] += prompt
         inter_inputs.append(inputs)
-        (inputs,) = await backend.forward_pool.submit_task(inputs)
+
+        if isinstance(backend.forward_pool, PrioritizedTaskPool):
+            (inputs,) = await backend.forward_pool.submit_task(inputs, pollen=pollen / 2.0)
+        else:
+            (inputs,) = await backend.forward_pool.submit_task(inputs)
+
         assert isinstance(inputs, torch.Tensor)
 
     if not is_dummy(prompts[-1]):
@@ -282,7 +321,10 @@ async def _rpc_backward(
     grad_prompts_reversed = []
     # Run a chain of requested backends
     for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        if isinstance(backend.backward_pool, PrioritizedTaskPool):
+            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, pollen=pollen / 2.0)
+        else:
+            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
             grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
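
The isinstance check on the pool type is repeated in rpc_inference, _rpc_forward and _rpc_backward. A hypothetical helper (not part of this commit) could centralize that dispatch; the name _submit_with_priority is an assumption, and it presumes PrioritizedTaskPool is imported as in the handler module.

# Hypothetical helper, not part of this commit: forwards pollen only when the
# pool supports prioritization, mirroring the branches in the handlers above.
async def _submit_with_priority(pool, *args, pollen: float = 0.0):
    if isinstance(pool, PrioritizedTaskPool):
        return await pool.submit_task(*args, pollen=pollen)
    return await pool.submit_task(*args)

Each call site would then reduce to a single awaited call such as _submit_with_priority(backend.forward_pool, hidden_states, pollen=pollen).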

+ 15 - 0
src/server/task_broker.py

@@ -0,0 +1,15 @@
+from abc import ABC, abstractmethod
+
+from hivemind.moe.server.task_pool import Task
+
+
+class TaskBrokerBase(ABC):
+    @abstractmethod
+    def __call__(self, task: Task, pollen: float) -> float:
+        pass
+
+
+class SimpleBroker(TaskBrokerBase):
+    def __call__(self, task: Task, pollen: float) -> float:
+        task_size = len(task.args[0]) if task.args else 1
+        return pollen / task_size
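
A broker decides how much a request's pollen is worth for a given task. As an illustration, a hypothetical alternative broker (not part of this commit) could normalize by the total number of tensor elements instead of the first batch dimension; it assumes the same TaskBrokerBase interface and that task.args holds torch tensors, as submitted by the pool.

# Hypothetical custom broker (not part of this commit): large requests need
# proportionally more pollen to reach the same priority as small ones.
class ElementCountBroker(TaskBrokerBase):
    def __call__(self, task: Task, pollen: float) -> float:
        total_elements = sum(tensor.numel() for tensor in task.args) or 1
        return pollen / total_elements

Such a broker could be plugged in when constructing the pool, e.g. PrioritizedTaskPool(..., broker=ElementCountBroker()).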