
Merge branch 'priority-tasks' of github.com:bigscience-workshop/petals into priority-tasks

 Conflicts:
	src/server/backend.py
justheuristic, 3 years ago
parent commit 09d5533326
5 changed files with 175 additions and 78 deletions
  1. src/client/inference_session.py  (+36, -6)
  2. src/client/remote_generation.py  (+3, -2)
  3. src/server/backend.py  (+19, -11)
  4. src/server/handler.py  (+114, -57)
  5. src/server/task_prioritizer.py  (+3, -2)

src/client/inference_session.py  (+36, -6)

@@ -22,6 +22,7 @@ from hivemind.proto import runtime_pb2
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -44,6 +45,7 @@ class RemoteTransformerBlockInferenceSession:
         max_length: int,
     ):
         self.uid, self.rpc_info = uid, rpc_info
+        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
@@ -69,19 +71,43 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"

-    def step(self, new_hidden_states: torch.Tensor):
-        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+    def step(
+        self,
+        new_hidden_states: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None,
+    ):
+        """
+        Inference step: send a chunk of input tensors and receive a chunk of outputs
+        :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
+          if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
+        """
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]"
+            assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
+            assert prompts.shape[2] <= new_hidden_states.shape[1]
+            assert prompts.shape[3] == new_hidden_states.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY
+        else:
+            assert len(hypo_ids) == len(new_hidden_states)
+            assert hypo_ids.dtype == torch.int64
+
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
+        inputs = (new_hidden_states, prompts, hypo_ids)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
                     ],
                     metadata=self._serialized_metadata if not self.stepped else None,
                 )
@@ -161,12 +187,16 @@ class RemoteSequentialInferenceSession:

         return self

-    def step(self, inputs: torch.Tensor):
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
         assert not self.closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
         for session in self.inference_sessions:
-            outputs = session.step(inputs)
+            outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
         return inputs
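
For context, a minimal sketch of how the extended step() signature might be used from the client, assuming an already-open RemoteSequentialInferenceSession named sess over num_blocks blocks; the sizes and session setup are illustrative assumptions, not part of this diff:

    import torch

    num_blocks, batch_size, prefix_len, hid_size = 4, 2, 8, 14336  # illustrative sizes
    # deep prompts: one learned prefix per remote block, [num_layers, batch, prefix_len, hid_size]
    intermediate_prompts = torch.zeros(num_blocks, batch_size, prefix_len, hid_size)
    hypo_ids = torch.arange(batch_size, dtype=torch.int64)  # identity reordering of hypotheses

    hidden = torch.randn(batch_size, prefix_len, hid_size)
    # sess is assumed to be an open RemoteSequentialInferenceSession spanning num_blocks blocks
    outputs = sess.step(hidden, prompts=intermediate_prompts, hypo_ids=hypo_ids)
    assert outputs.shape == hidden.shape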

src/client/remote_generation.py  (+3, -2)

@@ -105,11 +105,12 @@ class RemoteGenerationMixin:
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
+                intermediate_prompts = None
                 if self.config.pre_seq_len > 0 and len(outputs) == 1:
-                    prompts, _ = self.transformer.get_prompt(embs.size(0))
+                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                     embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = sess.step(embs)[:, -1]
+                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)


src/server/backend.py  (+19, -11)

@@ -1,21 +1,21 @@
 """Code for serving bloom blocks via hivemind-server"""
 """Code for serving bloom blocks via hivemind-server"""
-import ctypes
 import multiprocessing as mp
 import multiprocessing as mp
 import os
 import os
 import threading
 import threading
 from concurrent.futures import Future
 from concurrent.futures import Future
 from dataclasses import dataclass, field
 from dataclasses import dataclass, field
 from queue import Empty, PriorityQueue
 from queue import Empty, PriorityQueue
-from typing import Optional, Sequence, Tuple, Dict, Any, List
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 
 import torch
 import torch
-from hivemind import use_hivemind_log_handler
+from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import Task, TaskPool
 from hivemind.moe.server.task_pool import Task, TaskPool
 from hivemind.utils import InvalidStateError, get_logger
 from hivemind.utils import InvalidStateError, get_logger
 
 
 from src.bloom.from_pretrained import BloomBlock
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
 from src.server.cache import MemoryCache
+from src.utils.misc import is_dummy
 
 
 use_hivemind_log_handler("in_root_logger")
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 logger = get_logger(__file__)
@@ -119,23 +119,17 @@ class PrioritizedTaskPool(TaskPool):
     # TODO: this is a copy-paste of the original method, except that we use different queue
     def iterate_minibatches(self, *args, **kwargs):
         """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
-        print('IN iterate_minibatches')
         while True:
             try:
                 logger.debug(f"{self.name} getting next task")
                 task: PrioritizedTask = self.prioritized_task_queue.get(timeout=self.timeout)
-                print('IN iterate_minibatches - 1')
             except Empty:
                 logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
-                print('IN iterate_minibatches - 2')
                 continue

-            print('IN iterate_minibatches - 3')
             try:
                 if task.task.future.set_running_or_notify_cancel():
-                    print('IN iterate_minibatches - 4')
-                    yield [task.task]
-                    print('IN iterate_minibatches - 5')
+                    yield [task]
             except InvalidStateError as e:
                 logger.debug(f"Failed to add task to batch: {task.task.future} raised {e}")
 
 
@@ -158,18 +152,28 @@ class TransformerBackend(ModuleBackend):
         self.forward_pool = PrioritizedTaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
         self.backward_pool = PrioritizedTaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
+        self.inference_schema = (
+            (
+                *self.args_schema,
+                BatchTensorDescriptor((), dtype=self.dtype),
+                BatchTensorDescriptor((), dtype=torch.int64),
+            ),
+            self.kwargs_schema,
+        )

     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
-            hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
+            (hidden_states, hypo_ids) = inputs
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"

             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
+                if not is_dummy(hypo_ids):
+                    cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
                 layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
                 hidden_states, (new_k, new_v) = self.module.forward(
@@ -188,3 +192,7 @@

     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(super().get_info(), inference_schema=self.inference_schema)
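
The cache[:, hypo_ids] line above is what lets beam search reorder the attention cache in place between steps; a small self-contained illustration of that indexing pattern in plain torch (the shapes here are made up for the example):

    import torch

    # toy cache laid out as [2 (keys/values), batch, ...], as in the assert above
    cache = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
    hypo_ids = torch.tensor([2, 0, 0])  # beam search kept hypothesis 2 and duplicated hypothesis 0
    before = cache.clone()

    cache[:, :] = cache[:, hypo_ids]  # advanced indexing copies first, so the in-place write is safe

    assert torch.equal(cache[:, 0], before[:, 2])
    assert torch.equal(cache[:, 1], before[:, 0]) and torch.equal(cache[:, 2], before[:, 0])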

src/server/handler.py  (+114, -57)

@@ -20,6 +20,7 @@ from hivemind.utils.streaming import split_for_streaming

 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import PrioritizedTaskPool, TransformerBackend
+from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from src.utils.misc import DUMMY, is_dummy


@@ -28,11 +29,18 @@ class TransformerConnectionHandler(ConnectionHandler):

     module_backends: Dict[ModuleUID, TransformerBackend]

-    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
+    def __init__(
+        self,
+        dht: DHT,
+        module_backends: Dict[str, TransformerBackend],
+        inference_max_length: int,
+        task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
+    ):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
         self.inference_max_length = inference_max_length
+        self._prioritizer = task_prioritizer

     async def _gather_inputs(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -69,13 +77,18 @@ class TransformerConnectionHandler(ConnectionHandler):
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
+            points = metadata.get("__points", 0.0)

             if not requested_uids:
                 raise ValueError("User must specify at least one block for inference, but got none")
             assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_inference should have number of points as a number or None, got {points}"
             if not 0 <= max_length <= self.inference_max_length:
                 raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")

+            point_per_piece = points / max_length if max_length > 0 else 0.0
             batch_size = request.tensors[0].size[0] if request.tensors else 1

             cache_metadata = torch.tensor(
@@ -86,48 +99,67 @@ class TransformerConnectionHandler(ConnectionHandler):
             async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-                    metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-                    dust = metadata.get("__dust", 0.0)
+                    hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+
+                    # Cast inputs to backend dtype
+                    hidden_states = hidden_states.to(requested_backends[0].dtype)
+                    assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"

-                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+                    # parse deep prompts (optional argument)
+                    if prompts is None or is_dummy(prompts):
+                        prompts = [DUMMY] * len(requested_backends)
+                    else:
+                        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

+                    if not (len(requested_backends) == len(prompts)):
+                        raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+                    length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
                     if prefix_length + length_increment > max_length:
                         raise ValueError(
                             f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                             f" exceeds pre-allocated maximum {max_length}"
                         )

-                    # Cast inputs to backend dtype
-                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
-
                     # run request tensors through all requested modules, update caches
-                    for backend, cache_handle in zip(requested_backends, cache_handles):
+                    for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+                        if not is_dummy(prompt):
+                            hidden_states[:, : prompt.shape[1]] += prompt
+
                         cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+                        assert isinstance(
+                            hidden_states, torch.Tensor
+                        ), f"hidden states must be tensor, got {type(hidden_states)}"
                         assert (
-                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                            hidden_states.ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                        if isinstance(backend.inference_pool, PrioritizedTaskPool):
-                            hidden_states = await backend.inference_pool.submit_task(
-                                cache_metadata, *hidden_states, dust
-                            )
-                        else:
-                            hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
-                        assert isinstance(hidden_states, (list, tuple))
-                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                        assert isinstance(
+                            backend.inference_pool, PrioritizedTaskPool
+                        ), "petals support only prioritized pools"
+                        priority = self._prioritizer(
+                            cache_metadata, hidden_states, hypo_ids, points=point_per_piece / len(requested_backends)
+                        )
+                        (hidden_states,) = await backend.inference_pool.submit_task(
+                            cache_metadata,
+                            hidden_states,
+                            hypo_ids,
+                            priority=priority,
+                            backend=backend,
+                            type="inference",
+                        )

                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
                         tensors=[
                             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                             for result, proto in zip(
-                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
+                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
                             )
                         ]
                     )

                     # prepare for next step
-                    prefix_length += hidden_states[0].shape[1]
+                    prefix_length += hidden_states.shape[1]
                     request = await (anext(requests))
         finally:
             print("CLOSED RPC_INFERENCE")
@@ -138,9 +170,14 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        dust = metadata.get("__dust", 0.0)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_forward should have number of points as number or None, got {points}"

-        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, dust=dust)
+        hidden_states = await _rpc_forward(
+            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+        )
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

         # Serialize output and respond to client
@@ -158,9 +195,13 @@ class TransformerConnectionHandler(ConnectionHandler):
         uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_forward_stream should have number of points as number or None, got {points}"

         hidden_states = await _rpc_forward(
-            *flat_inputs, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
+            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
         )
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"

@@ -183,9 +224,14 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        dust = metadata.get("__dust", 0.0)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_backward should have number of points as number or None, got {points}"

-        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, dust=dust)
+        grads = await _rpc_backward(
+            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+        )

         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
@@ -210,9 +256,13 @@ class TransformerConnectionHandler(ConnectionHandler):
         uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_backward_stream should have number of points as number or None, got {points}"

         grads = await _rpc_backward(
-            *flat_tensors, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
+            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
         )

         # Modify grad_inputs_schema to support grad_prompts
@@ -267,7 +317,10 @@


 async def _rpc_forward(
-    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    prioritizer: TaskPrioritizerBase,
+    points: float = 0.0,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@@ -277,27 +330,29 @@ async def _rpc_forward(
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    hidden_states, *prompts = flat_tensors
+    hidden_states, prompts = flat_tensors
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     assert hidden_states.ndim == 3
-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
         if not is_dummy(prompt):
-            hidden_states[:, :pre_seq_len] += prompt
-        if isinstance(backend.forward_pool, PrioritizedTaskPool):
-            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, dust)
-        else:
-            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+            hidden_states[:, : prompt.shape[1]] += prompt
+
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
+        )
+        (hidden_states,) = await backend.forward_pool.submit_task(
+            hidden_states,
+            priority=priority,
+        )
         assert isinstance(hidden_states, torch.Tensor)
         assert (
             hidden_states.ndim == 3
@@ -308,20 +363,20 @@


 async def _rpc_backward(
-    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    prioritizer: TaskPrioritizerBase,
+    points: float = 0.0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, *prompts = flat_tensors
+    inputs, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

-    if not prompts or is_dummy(prompts[0]):
+    if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
-        pre_seq_len = 0
     else:
-        prompts = [prompts[0].to(requested_backends[0].dtype)]
-        prompts = [p.squeeze(0) for p in prompts[0].split(1)]
-        pre_seq_len = prompts[0].shape[-2]
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
@@ -329,31 +384,33 @@ async def _rpc_backward(
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs[:, :pre_seq_len] += prompt
+            inputs[:, : prompt.shape[1]] += prompt
         inter_inputs.append(inputs)
-
-        if isinstance(backend.forward_pool, PrioritizedTaskPool):
-            (inputs,) = await backend.forward_pool.submit_task(inputs, dust / 2.0)
-        else:
-            (inputs,) = await backend.forward_pool.submit_task(inputs)
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+        )
+        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)

         assert isinstance(inputs, torch.Tensor)

     if not is_dummy(prompts[-1]):
-        inputs[:, :pre_seq_len] += prompts[-1]
+        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
     inter_inputs.append(inputs)

     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
     grad_prompts_reversed = []
     # Run a chain of requested backends
     for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
-        if isinstance(backend.backward_pool, PrioritizedTaskPool):
-            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, dust / 2.0)
-        else:
-            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
+        )
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
-            grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
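
To make the points plumbing concrete, a hedged sketch of the accounting implied above: a client attaches a "__points" budget in the request metadata, and the handler splits it evenly across the requested backends (and, for rpc_inference, additionally over max_length). The budget value and the client-side code are illustrative assumptions; only the "__points" key and the splits come from the diff:

    from hivemind.utils.serializer import MSGPackSerializer

    points = 8.0  # hypothetical budget for one request
    metadata = MSGPackSerializer.dumps({"__points": points})  # goes into request.metadata

    # server side, per the code above:
    num_backends, max_length = 4, 128
    per_backend_forward = points / num_backends              # share used by _rpc_forward / _rpc_backward
    point_per_piece = points / max_length                    # rpc_inference: budget per generated token
    per_backend_inference = point_per_piece / num_backends   # share passed to the prioritizer per block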

src/server/task_prioritizer.py  (+3, -2)

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod

+import torch
 from hivemind.moe.server.task_pool import Task


@@ -7,7 +8,7 @@ class TaskPrioritizerBase(ABC):
     """Abstract class for DustBroker whose reponsibility is to evaluate task profit"""
     """Abstract class for DustBroker whose reponsibility is to evaluate task profit"""
 
 
     @abstractmethod
     @abstractmethod
-    def prioritize(self, task: Task, points: float, *args, **kwargs) -> float:
+    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
         """Evaluates task value by the amout of points given"""
         """Evaluates task value by the amout of points given"""
         pass
         pass
 
 
@@ -15,5 +16,5 @@ class TaskPrioritizerBase(ABC):
 class DummyTaskPrioritizer(TaskPrioritizerBase):
     """Simple implementation of DustBroker which counts amount of dust per task size"""

-    def __call__(self, task: Task, points: float, *args, **kwargs) -> float:
+    def __call__(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
         return 0.0
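
For reference, a hedged sketch of what a non-trivial prioritizer could look like under the new interface; the class and its heuristic are illustrative only (not part of this commit), and it assumes the pool serves smaller priority values first, as a standard PriorityQueue does:

    import torch

    from src.server.task_prioritizer import TaskPrioritizerBase


    class PointsPrioritizer(TaskPrioritizerBase):
        """Illustrative prioritizer: tasks that pay more points per tensor element run earlier."""

        def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
            total_elements = sum(t.numel() for t in input if isinstance(t, torch.Tensor))
            # normalize by task size so one huge batch cannot starve small requests
            return -points / max(total_elements, 1)

        # rpc_inference calls the prioritizer object directly, so mirror DummyTaskPrioritizer's __call__
        __call__ = prioritize

Such a prioritizer would be plugged in through the new task_prioritizer argument of TransformerConnectionHandler shown in the handler diff above.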