
prioritize task in handler before submitting task

Pavel Samygin 3 years ago
parent
commit
74abb1299e
2 changed files with 83 additions and 35 deletions
  1. 80 additions, 33 deletions
      src/server/handler.py
  2. 3 additions, 2 deletions
      src/server/task_prioritizer.py

+ 80 - 33
src/server/handler.py

@@ -20,6 +20,7 @@ from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import PrioritizedTaskPool, TransformerBackend
+from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from src.utils.misc import DUMMY, is_dummy
 
 
@@ -28,11 +29,18 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     module_backends: Dict[ModuleUID, TransformerBackend]
 
-    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
+    def __init__(
+        self,
+        dht: DHT,
+        module_backends: Dict[str, TransformerBackend],
+        inference_max_length: int,
+        task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
+    ):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
         self.inference_max_length = inference_max_length
+        self._prioritizer = task_prioritizer
 
     async def _gather_inputs(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -69,13 +77,18 @@ class TransformerConnectionHandler(ConnectionHandler):
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
+            points = metadata.get("__points", 0.0)
 
             if not requested_uids:
                 raise ValueError("User must specify at least one block for inference, but got none")
             assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_inference should have number of points as a number or None, got {points}"
             if not 0 <= max_length <= self.inference_max_length:
                 raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
 
+            point_per_piece = points / max_length if max_length > 0 else 0.0
             batch_size = request.tensors[0].size[0] if request.tensors else 1
 
             cache_metadata = torch.tensor(
@@ -86,9 +99,6 @@ class TransformerConnectionHandler(ConnectionHandler):
             async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
-                    metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-                    dust = metadata.get("__dust", 0.0)
-
                     hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
                     # Cast inputs to backend dtype
@@ -123,14 +133,20 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert (
                             hidden_states.ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                        if isinstance(backend.inference_pool, PrioritizedTaskPool):
-                            (hidden_states,) = await backend.inference_pool.submit_task(
-                                cache_metadata, hidden_states, hypo_ids, priority=dust
-                            )
-                        else:
-                            (hidden_states,) = await backend.inference_pool.submit_task(
-                                cache_metadata, hidden_states, hypo_ids
-                            )
+                        assert isinstance(
+                            backend.inference_pool, PrioritizedTaskPool
+                        ), "petals support only prioritized pools"
+                        priority = self._prioritizer.prioritize(
+                            cache_metadata, hidden_states, hypo_ids, points=point_per_piece / len(requested_backends)
+                        )
+                        (hidden_states,) = await backend.inference_pool.submit_task(
+                            cache_metadata,
+                            hidden_states,
+                            hypo_ids,
+                            priority=priority,
+                            backend=backend,
+                            type="inference",
+                        )
 
                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
@@ -154,9 +170,14 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        dust = metadata.get("__dust", 0.0)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_forward should have number of points as number or None, got {points}"
 
-        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, dust=dust)
+        hidden_states = await _rpc_forward(
+            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+        )
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
         # Serialize output and respond to client
@@ -174,9 +195,13 @@ class TransformerConnectionHandler(ConnectionHandler):
         uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
         hidden_states = await _rpc_forward(
-            *flat_inputs, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
+            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
         )
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
 
@@ -199,9 +224,14 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        dust = metadata.get("__dust", 0.0)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_backward should have number of points as number or None, got {points}"
 
-        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, dust=dust)
+        grads = await _rpc_backward(
+            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+        )
 
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
@@ -226,9 +256,13 @@ class TransformerConnectionHandler(ConnectionHandler):
         uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+        points = metadata.get("__points", 0.0)
+        assert isinstance(
+            points, (float, int)
+        ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
         grads = await _rpc_backward(
-            *flat_tensors, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
+            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
         )
 
         # Modify grad_inputs_schema to support grad_prompts
@@ -283,7 +317,10 @@ class TransformerConnectionHandler(ConnectionHandler):
 
 
 async def _rpc_forward(
-    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    prioritizer: TaskPrioritizerBase,
+    points: float = 0.0,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@@ -307,10 +344,15 @@ async def _rpc_forward(
     for backend, prompt in zip(requested_backends, prompts):
         if not is_dummy(prompt):
             hidden_states[:, : prompt.shape[1]] += prompt
-        if isinstance(backend.forward_pool, PrioritizedTaskPool):
-            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, priority=dust)
-        else:
-            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+
+        assert isinstance(backend.forward_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
+        )
+        (hidden_states,) = await backend.forward_pool.submit_task(
+            hidden_states,
+            priority=priority,
+        )
         assert isinstance(hidden_states, torch.Tensor)
         assert (
             hidden_states.ndim == 3
@@ -321,7 +363,10 @@ async def _rpc_forward(
 
 
 async def _rpc_backward(
-    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    prioritizer: TaskPrioritizerBase,
+    points: float = 0.0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
     inputs, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
@@ -341,11 +386,11 @@ async def _rpc_backward(
         if not is_dummy(prompt):
             inputs[:, : prompt.shape[1]] += prompt
         inter_inputs.append(inputs)
-
-        if isinstance(backend.forward_pool, PrioritizedTaskPool):
-            (inputs,) = await backend.forward_pool.submit_task(inputs, dust / 2.0)
-        else:
-            (inputs,) = await backend.forward_pool.submit_task(inputs)
+        assert isinstance(backend.forward_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+        )
+        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
 
         assert isinstance(inputs, torch.Tensor)
 
@@ -357,10 +402,12 @@ async def _rpc_backward(
     grad_prompts_reversed = []
     # Run a chain of requested backends
     for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
-        if isinstance(backend.backward_pool, PrioritizedTaskPool):
-            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, dust / 2.0)
-        else:
-            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(backend.backward_pool, PrioritizedTaskPool), "petals support only prioritized pools"
+        priority = prioritizer.prioritize(
+            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
+        )
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
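
Taken together, the handler changes swap the old "__dust" metadata key for "__points": every RPC now reads points from the request metadata, asserts it is a number, and splits it evenly across the requested backends (and, for inference, over at most max_length generated positions) before asking the prioritizer for a priority. Below is a minimal sketch of a custom prioritizer on top of the interface this commit introduces; the class name and the task-size heuristic are illustrative, and it assumes that PrioritizedTaskPool serves smaller priority values first, which is not spelled out in this diff.

import torch

from src.server.task_prioritizer import TaskPrioritizerBase


class PointsPerElementPrioritizer(TaskPrioritizerBase):
    """Hypothetical prioritizer: tasks that pay more points per processed element run earlier."""

    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        # Rough task-size estimate: the largest element count among the input tensors
        # (the hidden states dominate for forward, backward and inference requests).
        task_size = max((tensor.numel() for tensor in input), default=1)
        # Assuming lower values are dequeued first, more points per element of work
        # means a smaller (i.e. better) priority.
        return -points / max(task_size, 1)

A server would opt into it through the new constructor argument, e.g. TransformerConnectionHandler(dht, module_backends, inference_max_length=..., task_prioritizer=PointsPerElementPrioritizer()); by default the DummyTaskPrioritizer below keeps every task at priority 0.0.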

+ 3 - 2
src/server/task_prioritizer.py

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 
+import torch
 from hivemind.moe.server.task_pool import Task
 
 
@@ -7,7 +8,7 @@ class TaskPrioritizerBase(ABC):
     """Abstract class for DustBroker whose reponsibility is to evaluate task profit"""
 
     @abstractmethod
-    def prioritize(self, task: Task, points: float, *args, **kwargs) -> float:
+    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
         """Evaluates task value by the amout of points given"""
         pass
 
@@ -15,5 +16,5 @@ class TaskPrioritizerBase(ABC):
 class DummyTaskPrioritizer(TaskPrioritizerBase):
     """Simple implementation of DustBroker which counts amount of dust per task size"""
 
-    def __call__(self, task: Task, points: float, *args, **kwargs) -> float:
+    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
         return 0.0
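
Nothing in this commit constructs the metadata on the client side; the handler only reads it. As a rough sketch of the client's half of the contract, assuming only what the server reads here (request.metadata is a MessagePack-encoded dict, "max_length" is required for inference, "__points" defaults to 0.0) and that the hivemind runtime_pb2.ExpertRequest in use exposes the metadata bytes field this handler accesses, a request could be prepared as follows; the UID and point value are placeholders.

from hivemind.proto import runtime_pb2
from hivemind.utils.serializer import MSGPackSerializer

# Points to spend on this call; the server divides them across the requested backends
# (and over max_length positions for rpc_inference) before computing a priority.
metadata = MSGPackSerializer.dumps({"max_length": 1024, "__points": 8.0})
request = runtime_pb2.ExpertRequest(uid="bloom-block.0", metadata=metadata)  # placeholder module UID

# Round-trips to exactly the dict that rpc_inference deserializes on the server side.
assert MSGPackSerializer.loads(request.metadata) == {"max_length": 1024, "__points": 8.0}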