
Prioritize short inference, unmerge pools for long inference (#458)

Right now, long inference requests may occupy the Runtime for a few seconds without yielding it to process short requests, which are the most latency-sensitive ones. This PR fixes that by disallowing the merged pool for long requests and prioritizing the short ones.
Alexander Borzunov 2 years ago
parent
commit
056f22515a

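In short, each inference step now checks how many tokens it adds (batch size × new tokens) against a threshold: 128 tokens normally, or just 1 token under NF4 quantization until bitsandbytes ships an efficient NF4 kernel for parallel forward. Small steps keep using the merged inference pool; larger ones are routed block-by-block through unmerged pools. Below is a simplified, illustrative sketch of that decision only (strings stand in for the QuantType enum; the real logic is in the block_functions.py diff that follows):

# Simplified sketch of the dispatch rule added in block_functions.py.
# The constants mirror the diff below; "nf4" stands in for QuantType.NF4.
MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1  # NF4 parallel forward is slow, so only single-token steps merge

def choose_dispatch(batch_size: int, length_increment: int, quant_type: str) -> str:
    """Return how this inference step would be scheduled."""
    merge_max_tokens = (
        MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == "nf4" else MAX_SHORT_INFERENCE_TOKENS
    )
    can_merge_pools = batch_size * length_increment <= merge_max_tokens
    if can_merge_pools:
        return "merged inference pool, task type 'short_inference' (highest priority)"
    return "unmerged per-block pools, task type 'inference'"

print(choose_dispatch(batch_size=1, length_increment=1, quant_type="nf4"))     # merged
print(choose_dispatch(batch_size=1, length_increment=129, quant_type="int8"))  # unmerged
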
+ 30 - 14
src/petals/server/block_functions.py

@@ -16,8 +16,15 @@ from petals.server.backend import TransformerBackend
 from petals.server.memory_cache import Handle
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import TaskPrioritizerBase
+from petals.utils.convert_block import QuantType
 from petals.utils.misc import DUMMY, is_dummy
 
+# We prioritize short inference requests and make them use a *merged* inference pool,
+# so they are processed without interruptions and extra overheads
+# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
+MAX_SHORT_INFERENCE_TOKENS = 128
+MAX_NF4_SHORT_INFERENCE_TOKENS = 1
+
 
 async def run_rpc_forward(
     *flat_tensors: torch.Tensor,
@@ -127,9 +134,11 @@ async def iterate_rpc_inference(
     active_adapter: Optional[str],
     input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
     cache_handles: Sequence[Sequence[Handle]],
+    *,
     max_length: int,
     prioritizer: TaskPrioritizerBase,
     points: int,
+    quant_type: QuantType,
 ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
     assert len(cache_handles) == len(requested_backends)
 
@@ -138,6 +147,7 @@ async def iterate_rpc_inference(
 
     async for request, step_metadata in input_iterator:
         hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+        batch_size, length_increment, _ = hidden_states.shape
 
         # Cast inputs to backend dtype
         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -154,34 +164,40 @@ async def iterate_rpc_inference(
         if not (len(requested_backends) == len(prompts)):
             raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
 
-        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
         if prefix_length + length_increment > max_length:
             raise ValueError(
                 f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                 f" exceeds pre-allocated maximum {max_length}"
             )
 
+        merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
+        can_merge_pools = batch_size * length_increment <= merge_max_tokens
         priority = prioritizer.prioritize(
             hidden_states,
             hypo_ids,
             points=point_per_piece,
             requested_uids=requested_uids,
-            type="inference",
-        )
-
-        inference_infos = tuple(
-            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
-            for uid, handles in zip(requested_uids, cache_handles)
+            type="short_inference" if can_merge_pools else "inference",
         )
 
-        if hidden_states.numel() == 0:
-            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
-            # when user wants to pre-allocate cache or check that server *can* allocate that cache
-        else:
+        # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
+        # when user wants to pre-allocate cache or check that server *can* allocate that cache.
+        if hidden_states.numel() > 0:
             assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
-            (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
-                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
-            )
+            if can_merge_pools:
+                inference_infos = tuple(
+                    InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
+                    for uid, handles in zip(requested_uids, cache_handles)
+                )
+                (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
+                    hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
+                )
+            else:
+                for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
+                    inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
+                    (hidden_states,) = await backend.inference_pool.submit_task(
+                        hidden_states, hypo_ids, inference_infos, prompt, priority=priority
+                    )
 
         # serialize and send last layer outputs
         output_tensors = [

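The reason long requests go through unmerged pools is preemption granularity: submitted block-by-block, a long request gives the Runtime back after every block, so a newly arrived short step (priority 1.0) can jump ahead of the remaining blocks (priority 2.0) instead of waiting for one monolithic merged pass. A toy priority-queue simulation of that interleaving (illustrative only; the real scheduling is done by hivemind's Runtime and PrioritizedTaskPool):

import heapq

queue = []  # entries: (priority, submission order, description); a lower priority value runs first
for i in range(3):
    heapq.heappush(queue, (2.0, i, f"long request, block {i}"))  # unmerged: one task per block

print(heapq.heappop(queue)[2])  # long request, block 0 — the Runtime finishes one block...

# ...meanwhile a short request arrives and is tagged "short_inference" (priority 1.0)
heapq.heappush(queue, (1.0, 0, "short request, all blocks at once (merged pool)"))

while queue:
    print(heapq.heappop(queue)[2])
# short request, all blocks at once (merged pool)
# long request, block 1
# long request, block 2
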
+ 4 - 0
src/petals/server/handler.py

@@ -34,6 +34,7 @@ from petals.server.backend import TransformerBackend
 from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
 from petals.server.memory_cache import Handle
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
+from petals.utils.convert_block import QuantType
 
 logger = get_logger(__name__)
 
@@ -71,6 +72,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         session_timeout: float,
         step_timeout: float,
         task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
+        quant_type: QuantType,
     ):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
@@ -88,6 +90,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
         self._prioritizer = task_prioritizer
+        self.quant_type = quant_type
 
     async def add_p2p_handlers(self, *args, **kwargs) -> None:
         if self._listener_task is None:
@@ -176,6 +179,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         max_length=max_length,
                         prioritizer=self._prioritizer,
                         points=points,
+                        quant_type=self.quant_type,
                     ):
                         if can_push:
                             task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))

+ 1 - 0
src/petals/server/server.py

@@ -560,6 +560,7 @@ class ModuleContainer(threading.Thread):
                 request_timeout=request_timeout,
                 session_timeout=session_timeout,
                 step_timeout=step_timeout,
+                quant_type=QuantType[server_info.quant_type.upper()],
             )
             for i in range(num_handlers)
         ]

+ 5 - 4
src/petals/server/task_prioritizer.py

@@ -13,9 +13,10 @@ class TaskPrioritizerBase(ABC):
 
 
 class DummyTaskPrioritizer(TaskPrioritizerBase):
-    """Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
-
     def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
+        # Inference steps (especially short ones) go first since they are more latency-sensitive
+        if kwargs.get("type") == "short_inference":
+            return 1.0
         if kwargs.get("type") == "inference":
-            return 1.0  # inference steps go first since they are more latency-sensitive
-        return 2.0  # forward, backward
+            return 2.0
+        return 3.0  # Forward, backward

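After this change, DummyTaskPrioritizer distinguishes three tiers, where a lower return value means the task is scheduled earlier. A small usage sketch (using the import path shown in the handler.py diff above; the dummy tensor is only there to satisfy the signature):

import torch

from petals.server.task_prioritizer import DummyTaskPrioritizer

prioritizer = DummyTaskPrioritizer()
dummy = torch.empty(0)  # prioritize() accepts *input tensors; their contents are ignored here

print(prioritizer.prioritize(dummy, type="short_inference"))  # 1.0 — short steps via the merged pool
print(prioritizer.prioritize(dummy, type="inference"))        # 2.0 — long steps via unmerged pools
print(prioritizer.prioritize(dummy, type="forward"))          # 3.0 — forward/backward passes
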
+ 23 - 18
tests/test_block_exact_match.py

@@ -4,6 +4,7 @@ import pytest
 import torch
 
 from petals import AutoDistributedConfig, RemoteSequential
+from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
 
@@ -13,26 +14,30 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
     config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_sequential = RemoteSequential(config)
 
-    for block_index in random.sample(range(config.num_hidden_layers), 3):
-        remote_block = remote_sequential[block_index]
+    block_index = random.randint(0, config.num_hidden_layers - 1)
+    remote_block = remote_sequential[block_index]
 
-        inputs = torch.randn(1, 8, config.hidden_size)
-        outputs_forward = remote_block(inputs)
+    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size)
+    outputs_forward = remote_block(inputs)
 
-        outputs_inference = []
-        with torch.inference_mode():
-            with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
-                for i in range(inputs.shape[1]):
-                    outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+    outputs_inference = []
+    with torch.inference_mode():
+        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+            # Test long inference (unmerged inference pools)
+            outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :]))
 
-                # test that max length is respected
-                with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
-                    sess.step(inputs[:, -1:, :])
-                assert "Maximum length exceeded" in repr(exc_info.value)
-        outputs_inference = torch.cat(outputs_inference, dim=1)
+            # Test short inference (merged inference pools)
+            for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]):
+                outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
 
-        ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
-        (outputs_local,) = ref_block(inputs)
+            # test that max length is respected
+            with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
+                sess.step(inputs[:, -1:, :])
+            assert "Maximum length exceeded" in repr(exc_info.value)
+    outputs_inference = torch.cat(outputs_inference, dim=1)
 
-        assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
-        assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
+    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+    (outputs_local,) = ref_block(inputs)
+
+    assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
+    assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)

+ 1 - 1
tests/test_remote_sequential.py

@@ -40,7 +40,7 @@ def test_remote_sequential():
     assert hidden.shape == test_inputs.shape
     assert hidden.requires_grad
     second_half_outputs = second_half(hidden)
-    assert torch.allclose(second_half_outputs, full_outputs, atol=3e-4)
+    assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)
 
     (second_half_outputs * grad_proj).sum().backward()
     assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2)