
Merge branch 'main' into amd-gpus

Alexander Borzunov 2 years ago
parent
commit
1ba721d51e

+ 7 - 7
README.md

@@ -8,7 +8,7 @@
     <br>
 </p>
 
-Generate text with distributed [LLaMA 2 (70B)](https://huggingface.co/meta-llama/Llama-2-70b-hf), [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2), [LLaMA-65B](https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md), [Guanaco-65B](https://huggingface.co/timdettmers/guanaco-65b) or [BLOOM-176B](https://huggingface.co/bigscience/bloom) and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
+Generate text with distributed **LLaMA 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
 
 ```python
 from transformers import AutoTokenizer
@@ -37,7 +37,7 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
 
 🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
 
-💬 **Any questions?** Ping us in [our Discord](https://discord.gg/J29mCBNBvm)!
+💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
 
 ### Connect your GPU and increase Petals capacity
 
@@ -48,7 +48,7 @@ Petals is a community-run system &mdash; we rely on people sharing their GPUs. Y
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
 pip install git+https://github.com/bigscience-workshop/petals
-python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16
+python -m petals.cli.run_server stabilityai/StableBeluga2
 ```
 
 🪟 **Windows + WSL.** Follow the guide on our [Wiki](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows).
@@ -57,7 +57,7 @@ python -m petals.cli.run_server stabilityai/StableBeluga2 --torch_dtype float16
 
 ```bash
 sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm learningathome/petals:main \
-    python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2 --torch_dtype float16
+    python -m petals.cli.run_server --port 31330 stabilityai/StableBeluga2
 ```
 
 These commands will host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your machine. You can also host `meta-llama/Llama-2-70b-hf`, `meta-llama/Llama-2-70b-chat-hf`, repos with LLaMA-65B, `bigscience/bloom`, `bigscience/bloomz`, and other compatible models from 🤗 [Model Hub](https://huggingface.co/models), or [add support](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) for new model architectures.
@@ -68,7 +68,7 @@ These commands will host a part of [Stable Beluga 2](https://huggingface.co/stab
 python -m petals.cli.run_server meta-llama/Llama-2-70b-chat-hf --token YOUR_TOKEN_HERE
 ```
 
-💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multiple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/D9MwApKgWa)!
+💬 **FAQ.** Check out our [Wiki](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server) to learn how to use multiple GPUs, restart the server on reboot, etc. If you have any issues, ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
 
 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
 
@@ -96,8 +96,8 @@ Learning more:
 
 ## How does it work?
 
-- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then team up with people serving the other parts to run inference or fine-tuning.
-- Single-batch inference runs at up to 6 steps/sec for LLaMA 2 (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough for [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
+- Petals runs large language models like [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
+- Single-batch inference runs at **up to 6 steps/sec** for **LLaMA 2** (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
 - Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.
 
 <p align="center">
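For context on the client workflow the README above advertises (its Python snippet is cut off by the hunk boundaries), here is a minimal sketch. It assumes the `petals.AutoDistributedModelForCausalLM` entry point and uses Stable Beluga 2 as in the README; the model choice and prompt are illustrative only, not part of this commit.

```python
# Minimal client sketch: run distributed generation through the public swarm.
# Assumes petals exposes AutoDistributedModelForCausalLM, as referenced by the
# README this commit edits.
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "stabilityai/StableBeluga2"  # any model currently hosted by the swarm
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)  # remote servers host the transformer blocks

inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))  # e.g. "A cat sat on a mat..."
```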

+ 0 - 3
examples/prompt-tuning-sst2.ipynb

@@ -92,9 +92,6 @@
    },
    "outputs": [],
    "source": [
-    "# Choose a model you'd like to prompt-tune. We recommend starting with\n",
-    "# a smaller model (bigscience/bloom-7b1-petals) for faster prototyping.\n",
-    "# The code below uses LLaMA-65B.\n",
     "MODEL_NAME = \"enoch/llama-65b-hf\"\n",
     "\n",
     "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",

+ 14 - 2
src/petals/client/routing/sequence_manager.py

@@ -50,7 +50,7 @@ class SequenceManagerConfig:
     ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
     active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)
 
-    max_pinged: int = 5  # max servers to ping from each sequence side, per update
+    max_pinged: int = 3  # max servers to ping from each sequence side, per update
     ping_timeout: float = 2  # max time to wait for pings, per update
 
 
@@ -293,6 +293,8 @@ class RemoteSequenceManager:
         return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
 
     def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
+        client_server_rtts = self.ping_aggregator.to_dict()
+
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
@@ -300,7 +302,13 @@ class RemoteSequenceManager:
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
 
-            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
+            # We prefer servers serving longer spans to minimize the number of hops, but leave some randomization
+            # to distribute the load. We also exclude servers known to be unreachable.
+            eps = 1e-6
+            span_weights = np.array(
+                [span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
+                dtype=np.float64,
+            )
             chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
 
             assert chosen_span.start <= current_index < chosen_span.end
@@ -361,9 +369,13 @@ class RemoteSequenceManager:
             self.state.sequence_info.update_(new_block_infos)
 
             first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
+            middle_servers = [
+                span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans
+            ]
             last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
 
         pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
+        pinged_servers |= set(sample_up_to(middle_servers, self.config.max_pinged))
         pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
         self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)
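
A standalone restatement of the routing change in this file, for readers skimming the diff: spans are now weighted by their length rather than by reported throughput, and servers whose measured RTT is `np.inf` (unreachable from this client) are all but excluded. The helper below is illustrative only; it borrows the names `candidate_spans`, `span.length`, `span.peer_id`, and `client_server_rtts` from the hunks above.

```python
# Illustrative sketch of the new span-selection rule. Assumes each span exposes
# .length and .peer_id, and client_server_rtts maps peer_id -> round-trip time,
# with np.inf meaning "known to be unreachable".
import numpy as np

def choose_span(candidate_spans, client_server_rtts):
    eps = 1e-6  # near-zero weight: unreachable servers are picked only as a last resort
    span_weights = np.array(
        [span.length if client_server_rtts.get(span.peer_id) != np.inf else eps
         for span in candidate_spans],
        dtype=np.float64,
    )
    # Longer spans get proportionally higher probability (fewer hops per sequence),
    # while sampling instead of taking the argmax spreads load across overlapping servers.
    return np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
```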
 

+ 195 - 0
src/petals/server/block_functions.py

@@ -0,0 +1,195 @@
+"""
+This module implements server-side computations on served blocks: forward, backward and inference; used by handler
+"""
+from __future__ import annotations
+
+from typing import AsyncIterator, Optional, Sequence, Tuple, Union
+
+import torch
+from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.moe.expert_uid import ExpertUID
+from hivemind.proto import runtime_pb2
+from hivemind.utils.nested import nested_flatten
+
+from petals.data_structures import InferenceMetadata
+from petals.server.backend import TransformerBackend
+from petals.server.memory_cache import Handle
+from petals.server.task_pool import PrioritizedTaskPool
+from petals.server.task_prioritizer import TaskPrioritizerBase
+from petals.utils.misc import DUMMY, is_dummy
+
+
+async def run_rpc_forward(
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
+    prioritizer: TaskPrioritizerBase,
+    points: int = 0,
+) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # check parse input tensors and cast dtypes
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
+    if prompts is None or is_dummy(prompts):
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+    # Run a chain of requested backends
+    for backend, prompt in zip(requested_backends, prompts):
+        if not is_dummy(prompt):
+            hidden_states[:, : prompt.shape[1]] += prompt
+
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
+        priority = prioritizer.prioritize(
+            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
+        )
+        (hidden_states,) = await backend.forward_pool.submit_task(
+            hidden_states,
+            active_adapter,
+            priority=priority,
+        )
+        assert isinstance(hidden_states, torch.Tensor)
+        assert (
+            hidden_states.ndim == 3
+        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+
+    return hidden_states
+
+
+async def run_rpc_backward(
+    *flat_tensors: torch.Tensor,
+    requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
+    prioritizer: TaskPrioritizerBase,
+    points: int = 0,
+) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
+    inputs, grad_outputs, prompts = flat_tensors
+    # Cast inputs & grad outputs to backend dtype
+    inputs = inputs.to(requested_backends[0].dtype)
+    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+
+    if prompts is None or is_dummy(prompts):
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+    # Run a forward chain to collect intermediate inputs
+    # Note that we do not forward for the last module since we do not need its output
+    inter_inputs = []
+    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
+        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+        if not is_dummy(prompt):
+            inputs[:, : prompt.shape[1]] += prompt
+        inter_inputs.append(inputs)
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
+        priority = prioritizer.prioritize(
+            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+        )
+        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
+
+        assert isinstance(inputs, torch.Tensor)
+
+    if not is_dummy(prompts[-1]):
+        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
+    inter_inputs.append(inputs)
+
+    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
+    grad_prompts_reversed = []
+    # Run a chain of requested backends
+    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals supports only prioritized pools"
+        priority = prioritizer.prioritize(
+            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
+        )
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
+
+        assert isinstance(grad_outputs, torch.Tensor)
+        if not is_dummy(prompt):
+            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
+
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
+    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
+
+
+async def iterate_rpc_inference(
+    requested_uids: Sequence[ExpertUID],
+    requested_backends: Sequence[TransformerBackend],
+    active_adapter: Optional[str],
+    input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
+    cache_handles: Sequence[Sequence[Handle]],
+    max_length: int,
+    prioritizer: TaskPrioritizerBase,
+    points: int,
+) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
+    assert len(cache_handles) == len(requested_backends)
+
+    prefix_length = 0
+    point_per_piece = points / max_length if max_length > 0 else 0.0
+
+    async for request, step_metadata in input_iterator:
+        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+
+        # Cast inputs to backend dtype
+        hidden_states = hidden_states.to(requested_backends[0].dtype)
+        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
+
+        # parse deep prompts (optional argument)
+        has_prompts = prompts is not None and not is_dummy(prompts)
+        if not has_prompts:
+            prompts = [None] * len(requested_backends)
+        else:
+            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
+
+        if len(requested_backends) != len(prompts):
+            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
+        if prefix_length + length_increment > max_length:
+            raise ValueError(
+                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                f" exceeds pre-allocated maximum {max_length}"
+            )
+
+        priority = prioritizer.prioritize(
+            hidden_states,
+            hypo_ids,
+            points=point_per_piece,
+            requested_uids=requested_uids,
+            type="inference",
+        )
+
+        inference_infos = tuple(
+            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
+            for uid, handles in zip(requested_uids, cache_handles)
+        )
+
+        if hidden_states.numel() == 0:
+            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
+            # when user wants to pre-allocate cache or check that server *can* allocate that cache
+        else:
+            assert hidden_states.ndim == 3, "hidden states must be a single 3d tensor"
+            (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
+                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
+            )
+
+        # serialize and send last layer outputs
+        output_tensors = [
+            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
+        ]
+        can_push = not has_prompts
+        yield output_tensors, can_push
+
+        # prepare for next step
+        prefix_length += length_increment
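
One detail of the new module worth calling out: `run_rpc_forward`, `run_rpc_backward`, and `iterate_rpc_inference` all apply optional deep prompts with the same in-place pattern, `hidden_states[:, : prompt.shape[1]] += prompt`. A toy illustration of that convention (all shapes are made up for the example):

```python
# Toy illustration of the deep-prompt convention used in block_functions.py:
# each served block may get a per-block prompt that is added to the first
# prompt_len positions of its input; a DUMMY tensor means "no prompt".
import torch

batch, seq_len, prompt_len, hid = 2, 16, 4, 8           # made-up sizes
hidden_states = torch.zeros(batch, seq_len, hid)
prompt = torch.randn(batch, prompt_len, hid)             # one block's prompt, already squeezed to 3-D

hidden_states[:, : prompt.shape[1]] += prompt            # same line as in run_rpc_forward
assert torch.equal(hidden_states[:, :prompt_len], prompt)
assert hidden_states[:, prompt_len:].abs().sum() == 0    # positions past the prompt are untouched
```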

+ 2 - 1
src/petals/server/block_utils.py

@@ -11,7 +11,8 @@ def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
     if dtype not in ("auto", None):
         return dtype
-    if config.torch_dtype not in ("auto", None):
+    if config.torch_dtype not in ("auto", None, torch.float32):
+        # If config specifies float32, we override it to the default dtype below
         return config.torch_dtype
     return torch.bfloat16
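
To make the new `float32` exception concrete, a hedged usage sketch: `resolve_block_dtype` is the real function from this hunk, while the hand-built config stands in for a downloaded checkpoint config.

```python
# Behavior of the updated resolve_block_dtype (function shown in the hunk above;
# the PretrainedConfig instance here is constructed by hand just for illustration).
import torch
from transformers import PretrainedConfig
from petals.server.block_utils import resolve_block_dtype

config = PretrainedConfig()

config.torch_dtype = torch.float16
assert resolve_block_dtype(config, "auto") == torch.float16    # fp16/bf16 from the checkpoint is kept

config.torch_dtype = torch.float32
assert resolve_block_dtype(config, "auto") == torch.bfloat16   # new: fp32 configs fall back to bf16

assert resolve_block_dtype(config, torch.float32) == torch.float32  # an explicit dtype always wins
```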
 

+ 19 - 173
src/petals/server/handler.py

@@ -6,7 +6,7 @@ import multiprocessing as mp
 import sys
 from enum import Enum
 from itertools import chain
-from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple
 
 import torch
 from async_timeout import timeout
@@ -29,12 +29,11 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 
 import petals
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, InferenceMetadata, ModuleUID
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID
 from petals.server.backend import TransformerBackend
+from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
 from petals.server.memory_cache import Handle
-from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
-from petals.utils.misc import DUMMY, is_dummy
 
 logger = get_logger(__name__)
 
@@ -147,7 +146,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                 metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
                 requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                 max_length = metadata.get("max_length")
-                active_adapter = self._get_active_adapter(metadata)
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
                 if not requested_uids:
@@ -163,78 +161,28 @@ class TransformerConnectionHandler(ConnectionHandler):
                         f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
                     )
 
-                point_per_piece = points / max_length if max_length > 0 else 0.0
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
-                prefix_length = 0
 
                 async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
-                    assert len(cache_handles) == len(requested_backends)
-                    first_request = request
                     background_tasks = set()
-                    async for request, metadata in self._iterate_inference_steps(
-                        first_request, requests, session_id, requested_uids, context
+                    async for output_tensors, can_push in iterate_rpc_inference(
+                        requested_uids=requested_uids,
+                        requested_backends=requested_backends,
+                        active_adapter=self._get_active_adapter(metadata),
+                        input_iterator=self._iterate_inference_steps(
+                            request, requests, session_id, requested_uids, context
+                        ),
+                        cache_handles=cache_handles,
+                        max_length=max_length,
+                        prioritizer=self._prioritizer,
+                        points=points,
                     ):
-                        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
-
-                        # Cast inputs to backend dtype
-                        hidden_states = hidden_states.to(requested_backends[0].dtype)
-                        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
-
-                        # parse deep prompts (optional argument)
-                        has_prompts = prompts is not None and not is_dummy(prompts)
-                        if not has_prompts:
-                            prompts = [None] * len(requested_backends)
-                        else:
-                            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-                            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
-
-                        if not (len(requested_backends) == len(prompts)):
-                            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
-
-                        length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
-                        if prefix_length + length_increment > max_length:
-                            raise ValueError(
-                                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
-                                f" exceeds pre-allocated maximum {max_length}"
-                            )
-
-                        priority = self._prioritizer.prioritize(
-                            hidden_states,
-                            hypo_ids,
-                            points=point_per_piece,
-                            requested_uids=requested_uids,
-                            type="inference",
-                        )
-
-                        inference_infos = tuple(
-                            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
-                            for uid, handles in zip(requested_uids, cache_handles)
-                        )
-
-                        if hidden_states.numel() == 0:
-                            pass  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
-                            # when user wants to pre-allocate cache or check that server *can* allocate that cache
-                        else:
-                            assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
-                            (hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
-                                hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
-                            )
-
-                        # serialize and send last layer outputs
-                        output_tensors = [
-                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                            for result, proto in zip(
-                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
-                            )
-                        ]
-                        if not has_prompts:
+                        if can_push:
                             task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
                             background_tasks.add(task)  # Keep reference until it is done to save it from GC
                             task.add_done_callback(background_tasks.discard)
                         yield runtime_pb2.ExpertResponse(tensors=output_tensors)
 
-                        # prepare for next step
-                        prefix_length += length_increment
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
 
@@ -408,7 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_forward should have number of points as number or None, got {points}"
 
-            hidden_states = await _rpc_forward(
+            hidden_states = await run_rpc_forward(
                 *flat_inputs,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -435,7 +383,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
-            hidden_states = await _rpc_forward(
+            hidden_states = await run_rpc_forward(
                 *flat_inputs,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -486,7 +434,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_backward should have number of points as number or None, got {points}"
 
-            grads = await _rpc_backward(
+            grads = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -511,7 +459,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
-            grads = await _rpc_backward(
+            grads = await run_rpc_backward(
                 *flat_tensors,
                 requested_backends=requested_backends,
                 prioritizer=self._prioritizer,
@@ -621,105 +569,3 @@ class TransformerConnectionHandler(ConnectionHandler):
             result.update(block_info)
 
         return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
-
-
-async def _rpc_forward(
-    *flat_tensors: torch.Tensor,
-    requested_backends: Sequence[TransformerBackend],
-    active_adapter: str = "",
-    prioritizer: TaskPrioritizerBase,
-    points: int = 0,
-) -> torch.Tensor:
-    """
-    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
-
-    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
-    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
-    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
-    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
-    """
-    hidden_states, prompts = flat_tensors
-    dtype = requested_backends[0].dtype
-    # check parse input tensors and cast dtypes
-    hidden_states = hidden_states.to(dtype)
-    assert hidden_states.ndim == 3
-    if prompts is None or is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
-    else:
-        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-
-    # Run a chain of requested backends
-    for backend, prompt in zip(requested_backends, prompts):
-        if not is_dummy(prompt):
-            hidden_states[:, : prompt.shape[1]] += prompt
-
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
-        )
-        (hidden_states,) = await backend.forward_pool.submit_task(
-            hidden_states,
-            active_adapter,
-            priority=priority,
-        )
-        assert isinstance(hidden_states, torch.Tensor)
-        assert (
-            hidden_states.ndim == 3
-        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-
-    return hidden_states
-
-
-async def _rpc_backward(
-    *flat_tensors: torch.Tensor,
-    requested_backends: Sequence[TransformerBackend],
-    active_adapter: str = "",
-    prioritizer: TaskPrioritizerBase,
-    points: int = 0,
-) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, prompts = flat_tensors
-    # Cast inputs & grad outputs to backend dtype
-    inputs = inputs.to(requested_backends[0].dtype)
-    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
-
-    if prompts is None or is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
-    else:
-        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-
-    # Run a forward chain to collect intermediate inputs
-    # Note that we do not forward for the last module since we do not need its output
-    inter_inputs = []
-    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
-        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-        if not is_dummy(prompt):
-            inputs[:, : prompt.shape[1]] += prompt
-        inter_inputs.append(inputs)
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
-        )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
-
-        assert isinstance(inputs, torch.Tensor)
-
-    if not is_dummy(prompts[-1]):
-        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
-    inter_inputs.append(inputs)
-
-    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
-    grad_prompts_reversed = []
-    # Run a chain of requested backends
-    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
-        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
-        priority = prioritizer.prioritize(
-            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
-        )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
-
-        assert isinstance(grad_outputs, torch.Tensor)
-        if not is_dummy(prompt):
-            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
-
-    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
-    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape

+ 3 - 1
src/petals/server/server.py

@@ -690,7 +690,9 @@ class ModuleAnnouncerThread(threading.Thread):
 
             delay = self.update_period - (time.perf_counter() - start_time)
             if delay < 0:
-                logger.warning("Declaring blocs to DHT takes more than --update_period, consider increasing it")
+                logger.warning(
+                    f"Declaring blocks to DHT takes more than --update_period, consider increasing it (currently {self.update_period})"
+                )
             self.trigger.wait(max(delay, 0))
             self.trigger.clear()