Browse Source

benchmark

Just Heuristic 2 years ago
parent
commit
c2e3c13241

+ 1 - 20
src/petals/bloom/from_pretrained.py

@@ -45,26 +45,7 @@ def load_pretrained_block(
         cache_dir = DEFAULT_CACHE_DIR
 
     block = WrappedBloomBlock(config)
-    state_dict = _load_state_dict(
-        converted_model_name_or_path,
-        block_index,
-        config,
-        use_auth_token=use_auth_token,
-        cache_dir=cache_dir,
-        max_disk_space=max_disk_space,
-    )
-
-    if torch_dtype == "auto":
-        with torch.no_grad():
-            for name, param in block.named_parameters():
-                assert name in state_dict, f"{name} not in state dict"
-                param.data = param.data.to(state_dict[name].dtype)
-    else:
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        block = block.to(dtype=torch_dtype)
-
-    report = block.load_state_dict(state_dict, strict=True)
-    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
+    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, DEBUG NOT ACTUAL WEIGHTS!")
     return block
 
 

+ 1 - 5
src/petals/constants.py

@@ -1,9 +1,5 @@
 PUBLIC_INITIAL_PEERS = [
-    "/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
-    "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
-    "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
-    "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
 ]
 
 # The reachability API is currently used only when connecting to the public swarm
-REACHABILITY_API_URL = "http://health.petals.ml"
+REACHABILITY_API_URL = "REMOVED"

+ 2 - 0
src/petals/data_structures.py

@@ -6,6 +6,7 @@ from enum import Enum
 from typing import Any, Dict, Tuple
 
 from hivemind import PeerID
+from hivemind.moe.expert_uid import ExpertUID
 
 from petals.server.memory_cache import Handle
 
@@ -48,5 +49,6 @@ RPCInfo = Dict[str, Any]
 
 @dataclasses.dataclass(frozen=True)
 class InferenceMetadata:
+    uid: ExpertUID
     prefix_length: int
     cache_handles: Tuple[Handle, ...]

+ 48 - 13
src/petals/server/backend.py

@@ -3,10 +3,11 @@ from __future__ import annotations
 
 from collections import Counter
 from itertools import chain
-from typing import Any, Dict, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
 from hivemind import BatchTensorDescriptor, TensorDescriptor
+from hivemind.moe.expert_uid import ExpertUID
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 from tensor_parallel import TensorParallel
@@ -15,7 +16,7 @@ from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomAttention
 
 from petals.data_structures import InferenceMetadata
-from petals.server.memory_cache import MemoryCache
+from petals.server.memory_cache import Handle, MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
 
@@ -39,7 +40,7 @@ class TransformerBackend(ModuleBackend):
         device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
             self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
-        )
+        )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
             self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
         )
@@ -69,6 +70,21 @@ class TransformerBackend(ModuleBackend):
         for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
             self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
 
+    @staticmethod
+    def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
+        """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
+        assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
+        first_pool = next(iter(backends.values())).inference_pool
+        merged_pool = PrioritizedTaskPool(
+            _MergedInferenceStep(backends),
+            max_batch_size=first_pool.max_batch_size,
+            device=first_pool.device,
+            name=f"merged_inference",
+        )
+        for backend in backends.values():
+            assert not backend.inference_pool.is_alive()
+            backend.inference_pool = merged_pool
+
     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""
         head_dim = self.config.hidden_size // self.config.n_head
@@ -79,22 +95,20 @@ class TransformerBackend(ModuleBackend):
             cache_tensors.extend((keys, values))
         return cache_tensors
 
+    @torch.inference_mode()
     def inference_step(
         self,
         hidden_states: torch.Tensor,
         hypo_ids: torch.LongTensor,
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
-        with torch.inference_mode():
-            assert (
-                hidden_states.ndim == 3
-            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-            with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
-                self._reorder_cache_inplace(cache_tensors, hypo_ids)
-                layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
-                hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
-                self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
-                return (hidden_states,)
+        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+            self._reorder_cache_inplace(cache_tensors, hypo_ids)
+            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
+            hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
+            return (hidden_states,)
 
     def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
         """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
@@ -139,3 +153,24 @@ class TransformerBackend(ModuleBackend):
         dummy = torch.tensor([])
         for p in self.module.parameters():
             p.data = dummy
+
+
+class _MergedInferenceStep:
+    def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
+        self.backends = backends
+
+    def __call__(
+        self,
+        hidden_states: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        inference_infos: Sequence[InferenceMetadata],
+        *optional_prompts: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, ...]:
+        assert len(inference_infos) == len(
+            optional_prompts
+        ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
+        for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
+            if optional_prompt is not None:
+                hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
+            (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
+        return (hidden_states,)

+ 23 - 30
src/petals/server/handler.py

@@ -141,10 +141,11 @@ class TransformerConnectionHandler(ConnectionHandler):
                         assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
 
                         # parse deep prompts (optional argument)
-                        if prompts is None or is_dummy(prompts) or is_dummy(prompts):
-                            prompts = [DUMMY] * len(requested_backends)
+                        if prompts is None or is_dummy(prompts):
+                            prompts = [None] * len(requested_backends)
                         else:
                             prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+                            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
 
                         if not (len(requested_backends) == len(prompts)):
                             raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
@@ -156,33 +157,26 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 f" exceeds pre-allocated maximum {max_length}"
                             )
 
-                        # run request tensors through all requested modules, update caches
-                        for backend, backend_cache_handles, prompt in zip(requested_backends, cache_handles, prompts):
-                            if not is_dummy(prompt):
-                                hidden_states[:, : prompt.shape[1]] += prompt
-                            if hidden_states.numel() == 0:
-                                continue  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
-                                # when user wants to pre-allocate cache or check that server *can* allocate that cache
-
-                            metadata = InferenceMetadata(prefix_length, tuple(backend_cache_handles))
-                            assert isinstance(
-                                hidden_states, torch.Tensor
-                            ), f"hidden states must be tensor, got {type(hidden_states)}"
-                            assert (
-                                hidden_states.ndim == 3
-                            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                            assert isinstance(
-                                backend.inference_pool, PrioritizedTaskPool
-                            ), "petals support only prioritized pools"
-                            priority = self._prioritizer.prioritize(
-                                hidden_states,
-                                hypo_ids,
-                                points=point_per_piece / len(requested_backends),
-                                backend=backend,
-                                type="inference",
-                            )
-                            (hidden_states,) = await backend.inference_pool.submit_task(
-                                hidden_states, hypo_ids, metadata, priority=priority
+                        priority = self._prioritizer.prioritize(
+                            hidden_states,
+                            hypo_ids,
+                            points=point_per_piece,
+                            requested_uids=requested_uids,
+                            type="inference",
+                        )
+
+                        inference_infos = tuple(
+                            InferenceMetadata(uid, prefix_length, tuple(handles))
+                            for uid, handles in zip(requested_uids, cache_handles)
+                        )
+
+                        if hidden_states.numel() == 0:
+                            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
+                            # when user wants to pre-allocate cache or check that server *can* allocate that cache
+                        else:
+                            assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
+                            (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
+                                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
                             )
 
                         # serialize and send last layer outputs
@@ -444,7 +438,6 @@ async def _rpc_forward(
             hidden_states.ndim == 3
         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
 
-    # Serialize the overall output
     return hidden_states
 
 

+ 11 - 3
src/petals/server/server.py

@@ -453,11 +453,12 @@ class ModuleContainer(threading.Thread):
             joining_announcer.stop.set()
             joining_announcer.join()
 
+        TransformerBackend.merge_inference_pools_inplace(blocks)
+
         return cls(
             dht,
             blocks,
             throughput=throughput,
-            device=device,
             update_period=update_period,
             expiration=expiration,
             **kwargs,
@@ -476,7 +477,6 @@ class ModuleContainer(threading.Thread):
         request_timeout: float,
         session_timeout: float,
         step_timeout: float,
-        device: Union[str, torch.device],
         start: bool,
         **kwargs,
     ):
@@ -495,7 +495,7 @@ class ModuleContainer(threading.Thread):
             )
             for _ in range(num_handlers)
         ]
-        self.runtime = Runtime(self.module_backends, device=None, **kwargs)
+        self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
         self.online_announcer = ModuleAnnouncerThread(
             list(self.module_backends.keys()),
@@ -633,3 +633,11 @@ class ModuleAnnouncerThread(threading.Thread):
             )
             if self.stop.wait(self.update_period):
                 break
+
+
+class RuntimeWithDeduplicatedPools(Runtime):
+    """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pools = tuple(set(self.pools))

+ 3 - 1
src/petals/server/task_prioritizer.py

@@ -16,4 +16,6 @@ class DummyTaskPrioritizer(TaskPrioritizerBase):
     """Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
 
     def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
-        return 0.0
+        if kwargs.get("type") == "inference":
+            return 1.0  # inference steps go first since they are more latency-sensitive
+        return 2.0  # forward, backward