Procházet zdrojové kódy

Merge inference pools into one to increase inference speed (#225)

It turns out using a separate pool for each block has led to significant slowdown, see #224 for details.
justheuristic před 2 roky
rodič
revize
c4938bc23e

+ 2 - 2
.github/workflows/check-style.yaml

@@ -9,7 +9,7 @@ jobs:
   black:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - uses: psf/black@stable
         with:
           options: "--check --diff"
@@ -17,7 +17,7 @@ jobs:
   isort:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - uses: actions/setup-python@v2
         with:
           python-version: 3.8

+ 1 - 1
.github/workflows/push-docker-image.yaml

@@ -14,7 +14,7 @@ jobs:
 
     steps:
       - name: Checkout
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
 
       - name: Docker meta
         id: meta

+ 2 - 2
.github/workflows/run-tests.yaml

@@ -13,7 +13,7 @@ jobs:
     timeout-minutes: 15
     steps:
       - name: Checkout
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
       - name: Check if the model is cached
         id: cache-model
         uses: actions/cache@v3
@@ -64,7 +64,7 @@ jobs:
     timeout-minutes: 15
     steps:
       - name: Checkout
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
       - name: Set up Python
         uses: actions/setup-python@v2
         with:

+ 2 - 0
src/petals/data_structures.py

@@ -6,6 +6,7 @@ from enum import Enum
 from typing import Any, Dict, Tuple
 
 from hivemind import PeerID
+from hivemind.moe.expert_uid import ExpertUID
 
 from petals.server.memory_cache import Handle
 
@@ -48,5 +49,6 @@ RPCInfo = Dict[str, Any]
 
 @dataclasses.dataclass(frozen=True)
 class InferenceMetadata:
+    uid: ExpertUID
     prefix_length: int
     cache_handles: Tuple[Handle, ...]

+ 48 - 13
src/petals/server/backend.py

@@ -3,10 +3,11 @@ from __future__ import annotations
 
 from collections import Counter
 from itertools import chain
-from typing import Any, Dict, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
 from hivemind import BatchTensorDescriptor, TensorDescriptor
+from hivemind.moe.expert_uid import ExpertUID
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 from tensor_parallel import TensorParallel
@@ -15,7 +16,7 @@ from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomAttention
 
 from petals.data_structures import InferenceMetadata
-from petals.server.memory_cache import MemoryCache
+from petals.server.memory_cache import Handle, MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
 
@@ -39,7 +40,7 @@ class TransformerBackend(ModuleBackend):
         device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
             self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
-        )
+        )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
             self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
         )
@@ -79,22 +80,20 @@ class TransformerBackend(ModuleBackend):
             cache_tensors.extend((keys, values))
         return cache_tensors
 
+    @torch.inference_mode()
     def inference_step(
         self,
         hidden_states: torch.Tensor,
         hypo_ids: torch.LongTensor,
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
-        with torch.inference_mode():
-            assert (
-                hidden_states.ndim == 3
-            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-            with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
-                self._reorder_cache_inplace(cache_tensors, hypo_ids)
-                layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
-                hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
-                self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
-                return (hidden_states,)
+        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+            self._reorder_cache_inplace(cache_tensors, hypo_ids)
+            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
+            hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
+            return (hidden_states,)
 
     def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
         """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
@@ -139,3 +138,39 @@ class TransformerBackend(ModuleBackend):
         dummy = torch.tensor([])
         for p in self.module.parameters():
             p.data = dummy
+
+
+def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
+    """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
+    assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
+    first_pool = next(iter(backends.values())).inference_pool
+    merged_pool = PrioritizedTaskPool(
+        _MergedInferenceStep(backends),
+        max_batch_size=first_pool.max_batch_size,
+        device=first_pool.device,
+        name=f"merged_inference",
+    )
+    for backend in backends.values():
+        assert not backend.inference_pool.is_alive()
+        backend.inference_pool = merged_pool
+
+
+class _MergedInferenceStep:
+    def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
+        self.backends = backends
+
+    def __call__(
+        self,
+        hidden_states: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        inference_infos: Sequence[InferenceMetadata],
+        *optional_prompts: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, ...]:
+        assert len(inference_infos) == len(
+            optional_prompts
+        ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
+        for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
+            if optional_prompt is not None:
+                hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
+            (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
+        return (hidden_states,)

+ 23 - 30
src/petals/server/handler.py

@@ -141,10 +141,11 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
 
                         # parse deep prompts (optional argument)
-                        if prompts is None or is_dummy(prompts) or is_dummy(prompts):
-                            prompts = [DUMMY] * len(requested_backends)
+                        if prompts is None or is_dummy(prompts):
+                            prompts = [None] * len(requested_backends)
                         else:
                             prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+                            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
 
                         if not (len(requested_backends) == len(prompts)):
                             raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
@@ -156,33 +157,26 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 f" exceeds pre-allocated maximum {max_length}"
                             )
 
-                        # run request tensors through all requested modules, update caches
-                        for backend, backend_cache_handles, prompt in zip(requested_backends, cache_handles, prompts):
-                            if not is_dummy(prompt):
-                                hidden_states[:, : prompt.shape[1]] += prompt
-                            if hidden_states.numel() == 0:
-                                continue  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
-                                # when user wants to pre-allocate cache or check that server *can* allocate that cache
-
-                            metadata = InferenceMetadata(prefix_length, tuple(backend_cache_handles))
-                            assert isinstance(
-                                hidden_states, torch.Tensor
-                            ), f"hidden states must be tensor, got {type(hidden_states)}"
-                            assert (
-                                hidden_states.ndim == 3
-                            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                            assert isinstance(
-                                backend.inference_pool, PrioritizedTaskPool
-                            ), "petals support only prioritized pools"
-                            priority = self._prioritizer.prioritize(
-                                hidden_states,
-                                hypo_ids,
-                                points=point_per_piece / len(requested_backends),
-                                backend=backend,
-                                type="inference",
-                            )
-                            (hidden_states,) = await backend.inference_pool.submit_task(
-                                hidden_states, hypo_ids, metadata, priority=priority
+                        priority = self._prioritizer.prioritize(
+                            hidden_states,
+                            hypo_ids,
+                            points=point_per_piece,
+                            requested_uids=requested_uids,
+                            type="inference",
+                        )
+
+                        inference_infos = tuple(
+                            InferenceMetadata(uid, prefix_length, tuple(handles))
+                            for uid, handles in zip(requested_uids, cache_handles)
+                        )
+
+                        if hidden_states.numel() == 0:
+                            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
+                            # when user wants to pre-allocate cache or check that server *can* allocate that cache
+                        else:
+                            assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
+                            (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
+                                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
                             )
 
                         # serialize and send last layer outputs
@@ -444,7 +438,6 @@ async def _rpc_forward(
             hidden_states.ndim == 3
         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
 
-    # Serialize the overall output
     return hidden_states
 
 

+ 12 - 4
src/petals/server/server.py

@@ -22,7 +22,7 @@ from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
-from petals.server.backend import TransformerBackend
+from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
@@ -453,11 +453,12 @@ class ModuleContainer(threading.Thread):
             joining_announcer.stop.set()
             joining_announcer.join()
 
+        merge_inference_pools_inplace(blocks)
+
         return cls(
             dht,
             blocks,
             throughput=throughput,
-            device=device,
             update_period=update_period,
             expiration=expiration,
             **kwargs,
@@ -476,7 +477,6 @@ class ModuleContainer(threading.Thread):
         request_timeout: float,
         session_timeout: float,
         step_timeout: float,
-        device: Union[str, torch.device],
         start: bool,
         **kwargs,
     ):
@@ -495,7 +495,7 @@ class ModuleContainer(threading.Thread):
             )
             for _ in range(num_handlers)
         ]
-        self.runtime = Runtime(self.module_backends, device=None, **kwargs)
+        self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
         self.online_announcer = ModuleAnnouncerThread(
             list(self.module_backends.keys()),
@@ -633,3 +633,11 @@ class ModuleAnnouncerThread(threading.Thread):
             )
             if self.stop.wait(self.update_period):
                 break
+
+
+class RuntimeWithDeduplicatedPools(Runtime):
+    """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pools = tuple(set(self.pools))

+ 3 - 1
src/petals/server/task_prioritizer.py

@@ -16,4 +16,6 @@ class DummyTaskPrioritizer(TaskPrioritizerBase):
     """Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
 
     def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
-        return 0.0
+        if kwargs.get("type") == "inference":
+            return 1.0  # inference steps go first since they are more latency-sensitive
+        return 2.0  # forward, backward