5
0
Эх сурвалжийг харах

Merge branch 'priority-tasks' of github.com:bigscience-workshop/petals into priority-tasks

 Conflicts:
	src/server/backend.py
justheuristic 2 жил өмнө
parent
commit
09d5533326

+ 36 - 6
src/client/inference_session.py

@@ -22,6 +22,7 @@ from hivemind.proto import runtime_pb2
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -44,6 +45,7 @@ class RemoteTransformerBlockInferenceSession:
         max_length: int,
     ):
         self.uid, self.rpc_info = uid, rpc_info
+        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
@@ -69,19 +71,43 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
-    def step(self, new_hidden_states: torch.Tensor):
-        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+    def step(
+        self,
+        new_hidden_states: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None,
+    ):
+        """
+        Inference step: send a chunk of input tesors and receive a chunk of outputs
+        :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
+          if specified, deep promts should have shape [num_layers, batch_size, prefix_len, hid_size]
+        """
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4, "deep promts should have shape [num_layers, batch_size, prefix_len, hid_size]"
+            assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
+            assert prompts.shape[2] <= new_hidden_states.shape[1]
+            assert prompts.shape[3] == new_hidden_states.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY
+        else:
+            assert len(hypo_ids) == len(new_hidden_states)
+            assert hypo_ids.dtype == torch.int64
+
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
+        inputs = (new_hidden_states, prompts, hypo_ids)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
                     ],
                     metadata=self._serialized_metadata if not self.stepped else None,
                 )
@@ -161,12 +187,16 @@ class RemoteSequentialInferenceSession:
 
         return self
 
-    def step(self, inputs: torch.Tensor):
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
         assert not self.closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
         for session in self.inference_sessions:
-            outputs = session.step(inputs)
+            outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
         return inputs

+ 3 - 2
src/client/remote_generation.py

@@ -105,11 +105,12 @@ class RemoteGenerationMixin:
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
+                intermediate_prompts = None
                 if self.config.pre_seq_len > 0 and len(outputs) == 1:
-                    prompts, _ = self.transformer.get_prompt(embs.size(0))
+                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                     embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = sess.step(embs)[:, -1]
+                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
 

+ 19 - 11
src/server/backend.py

@@ -1,21 +1,21 @@
 """Code for serving bloom blocks via hivemind-server"""
-import ctypes
 import multiprocessing as mp
 import os
 import threading
 from concurrent.futures import Future
 from dataclasses import dataclass, field
 from queue import Empty, PriorityQueue
-from typing import Optional, Sequence, Tuple, Dict, Any, List
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
-from hivemind import use_hivemind_log_handler
+from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import Task, TaskPool
 from hivemind.utils import InvalidStateError, get_logger
 
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
+from src.utils.misc import is_dummy
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -119,23 +119,17 @@ class PrioritizedTaskPool(TaskPool):
     # TODO: this is a copy-paste of the original method, except that we use different queue
     def iterate_minibatches(self, *args, **kwargs):
         """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
-        print('IN iterate_minibatches')
         while True:
             try:
                 logger.debug(f"{self.name} getting next task")
                 task: PrioritizedTask = self.prioritized_task_queue.get(timeout=self.timeout)
-                print('IN iterate_minibatches - 1')
             except Empty:
                 logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
-                print('IN iterate_minibatches - 2')
                 continue
 
-            print('IN iterate_minibatches - 3')
             try:
                 if task.task.future.set_running_or_notify_cancel():
-                    print('IN iterate_minibatches - 4')
-                    yield [task.task]
-                    print('IN iterate_minibatches - 5')
+                    yield [task]
             except InvalidStateError as e:
                 logger.debug(f"Failed to add task to batch: {task.task.future} raised {e}")
 
@@ -158,18 +152,28 @@ class TransformerBackend(ModuleBackend):
         self.forward_pool = PrioritizedTaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
         self.backward_pool = PrioritizedTaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
+        self.inference_schema = (
+            (
+                *self.args_schema,
+                BatchTensorDescriptor((), dtype=self.dtype),
+                BatchTensorDescriptor((), dtype=torch.int64),
+            ),
+            self.kwargs_schema,
+        )
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
-            hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
+            (hidden_states, hypo_ids) = inputs
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
+                if not is_dummy(hypo_ids):
+                    cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
                 layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
                 hidden_states, (new_k, new_v) = self.module.forward(
@@ -188,3 +192,7 @@ class TransformerBackend(ModuleBackend):
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(super().get_info(), inference_schema=self.inference_schema)

+ 114 - 57
src/server/handler.py

@@ -20,6 +20,7 @@ from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import PrioritizedTaskPool, TransformerBackend
+from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from src.utils.misc import DUMMY, is_dummy
 
 
@@ -28,11 +29,18 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     module_backends: Dict[ModuleUID, TransformerBackend]
 
-    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
+    def __init__(
+        self,
+        dht: DHT,
+        module_backends: Dict[str, TransformerBackend],
+        inference_max_length: int,
+        task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
+    ):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
         self.inference_max_length = inference_max_length
+        self._prioritizer = task_prioritizer
 
     async def _gather_inputs(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -69,13 +77,18 @@ class TransformerConnectionHandler(ConnectionHandler):
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
+            points = metadata.get("__points", 0.0)
 
             if not requested_uids:
                 raise ValueError("User must specify at least one block for inference, but got none")
             assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_inference should have number of points as a number or None, got {points}"
             if not 0 <= max_length <= self.inference_max_length:
                 raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
 
+            point_per_piece = points / max_length if max_length > 0 else 0.0
             batch_size = request.tensors[0].size[0] if request.tensors else 1
 
             cache_metadata = torch.tensor(
@@ -86,48 +99,67 @@ class TransformerConnectionHandler(ConnectionHandler):
             async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-                    metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-                    dust = metadata.get("__dust", 0.0)
+                    hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+
+                    # Cast inputs to backend dtype
+                    hidden_states = hidden_states.to(requested_backends[0].dtype)
+                    assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
 
-                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+                    # parse deep prompts (optional argument)
+                    if prompts is None or is_dummy(prompts) or is_dummy(prompts):
+                        prompts = [DUMMY] * len(requested_backends)
+                    else:
+                        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
+                    if not (len(requested_backends) == len(prompts)):
+                        raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+                    length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
                     if prefix_length + length_increment > max_length:
                         raise ValueError(
                             f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                             f" exceeds pre-allocated maximum {max_length}"
                         )
 
-                    # Cast inputs to backend dtype
-                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
                     # run request tensors through all requested modules, update caches
-                    for backend, cache_handle in zip(requested_backends, cache_handles):
+                    for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+                        if not is_dummy(prompt):
+                            hidden_states[:, : prompt.shape[1]] += prompt
+
                         cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+                        assert isinstance(
+                            hidden_states, torch.Tensor
+                        ), f"hidden states must be tensor, got {type(hidden_states)}"
                         assert (
-                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                            hidden_states.ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                        if isinstance(backend.inference_pool, PrioritizedTaskPool):
-                            hidden_states = await backend.inference_pool.submit_task(
-                                cache_metadata, *hidden_states, dust
-                            )
-                        else:
-                            hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
-                        assert isinstance(hidden_states, (list, tuple))
-                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                        assert isinstance(
+                            backend.inference_pool, PrioritizedTaskPool
+                        ), "petals support only prioritized pools"
+                        priority = self._prioritizer(
+                            cache_metadata, hidden_states, hypo_ids, points=point_per_piece / len(requested_backends)
+                        )
+                        (hidden_states,) = await backend.inference_pool.submit_task(
+                            cache_metadata,
+                            hidden_states,
+                            hypo_ids,
+                            priority=priority,
+                            backend=backend,
+                            type="inference",
+                        )
 
                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
                         tensors=[
                             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                             for result, proto in zip(
-                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
+                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
                             )
                         ]
                     )
 
                     # prepare for next step
-                    prefix_length += hidden_states[0].shape[1]
+                    prefix_length += hidden_states.shape[1]
                     request = await (anext(requests))
         finally:
             print("CLOSED RPC_INFERENCE")
@@ -138,9 +170,14 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        dust = metadata.get("__dust", 0.0)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_forward should have number of points as number or None, got {points}"
 
-        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, dust=dust)
+        hidden_states = await _rpc_forward(
+            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+        )
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
         # Serialize output and respond to client
@@ -158,9 +195,13 @@ class TransformerConnectionHandler(ConnectionHandler):
         uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
         hidden_states = await _rpc_forward(
-            *flat_inputs, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
+            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
         )
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
 
@@ -183,9 +224,14 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        dust = metadata.get("__dust", 0.0)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_backward should have number of points as number or None, got {points}"
 
-        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, dust=dust)
+        grads = await _rpc_backward(
+            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+        )
 
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
@@ -210,9 +256,13 @@ class TransformerConnectionHandler(ConnectionHandler):
         uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
         grads = await _rpc_backward(
-            *flat_tensors, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
+            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
         )
 
         # Modify grad_inputs_schema to support grad_prompts
@@ -267,7 +317,10 @@ class TransformerConnectionHandler(ConnectionHandler):
 
 
 async def _rpc_forward(
-    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    prioritizer: TaskPrioritizerBase,
+    points: float = 0.0,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@@ -277,27 +330,29 @@ async def _rpc_forward(
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    hidden_states, *prompts = flat_tensors
+    hidden_states, prompts = flat_tensors
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     assert hidden_states.ndim == 3
-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
         if not is_dummy(prompt):
-            hidden_states[:, :pre_seq_len] += prompt
-        if isinstance(backend.forward_pool, PrioritizedTaskPool):
-            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, dust)
-        else:
-            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+            hidden_states[:, : prompt.shape[1]] += prompt
+
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
+        )
+        (hidden_states,) = await backend.forward_pool.submit_task(
+            hidden_states,
+            priority=priority,
+        )
         assert isinstance(hidden_states, torch.Tensor)
         assert (
             hidden_states.ndim == 3
@@ -308,20 +363,20 @@ async def _rpc_forward(
 
 
 async def _rpc_backward(
-    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    prioritizer: TaskPrioritizerBase,
+    points: float = 0.0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, *prompts = flat_tensors
+    inputs, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
 
-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
@@ -329,31 +384,33 @@ async def _rpc_backward(
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs[:, :pre_seq_len] += prompt
+            inputs[:, : prompt.shape[1]] += prompt
         inter_inputs.append(inputs)
-
-        if isinstance(backend.forward_pool, PrioritizedTaskPool):
-            (inputs,) = await backend.forward_pool.submit_task(inputs, dust / 2.0)
-        else:
-            (inputs,) = await backend.forward_pool.submit_task(inputs)
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+        )
+        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
 
         assert isinstance(inputs, torch.Tensor)
 
     if not is_dummy(prompts[-1]):
-        inputs[:, :pre_seq_len] += prompts[-1]
+        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
     inter_inputs.append(inputs)
 
     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
     grad_prompts_reversed = []
     # Run a chain of requested backends
     for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
-        if isinstance(backend.backward_pool, PrioritizedTaskPool):
-            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, dust / 2.0)
-        else:
-            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
+        )
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
-            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
 
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape

+ 3 - 2
src/server/task_prioritizer.py

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 
+import torch
 from hivemind.moe.server.task_pool import Task
 
 
@@ -7,7 +8,7 @@ class TaskPrioritizerBase(ABC):
     """Abstract class for DustBroker whose reponsibility is to evaluate task profit"""
 
     @abstractmethod
-    def prioritize(self, task: Task, points: float, *args, **kwargs) -> float:
+    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
         """Evaluates task value by the amout of points given"""
         pass
 
@@ -15,5 +16,5 @@ class TaskPrioritizerBase(ABC):
 class DummyTaskPrioritizer(TaskPrioritizerBase):
     """Simple implementation of DustBroker which counts amount of dust per task size"""
 
-    def __call__(self, task: Task, points: float, *args, **kwargs) -> float:
+    def __call__(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
         return 0.0