
Merge remote-tracking branch 'origin/main' into forward_kwargs

# Conflicts:
#	src/petals/__init__.py
#	src/petals/client/inference_session.py
Your Name, 1 year ago
Parent
Commit
3195579620

+ 12 - 13
.github/workflows/run-tests.yaml

@@ -48,7 +48,6 @@ jobs:
           export MODEL_NAME="${{ matrix.model }}"
           export REF_NAME="${{ matrix.model }}"
           export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
-          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
 
           # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
 
@@ -61,27 +60,25 @@ jobs:
 
           until [ -s bootstrap.log ]; do sleep 5; done  # wait for DHT init
 
-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
-            --mean_balance_check_period 10 \
-            --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
+          export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \
+            --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS"
+          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
+
+          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log &
           SERVER1_PID=$!
          # ^-- rebalancing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there
 
           sleep 10  # wait for the 1st server to choose blocks
 
-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
-            --identity_path tests/server2.id \
-            --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
+          $RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log &
           SERVER2_PID=$!
 
-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
-            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
-            --initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
+          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \
+            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log &
           SERVER3_PID=$!
           # ^-- chunking test
 
-          python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
-            --initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
+          $RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log &
           SERVER4_PID=$!
           # ^-- tensor parallelism test (not compatible with adapters yet)
 
@@ -102,6 +99,9 @@ jobs:
           export no_proxy=*
           export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
 
+          # Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely
+          export PETALS_MAX_RETRIES=10
+
           pytest tests --durations=0 --durations-min=1.0 -v
 
           # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
@@ -118,4 +118,3 @@ jobs:
           # [Step 4] Clean up
 
           kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
-          echo "Done!"

+ 12 - 97
README.md

@@ -8,14 +8,14 @@
     <br>
 </p>
 
-Generate text with distributed **Llama 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
+Generate text with distributed **Llama 2** (70B), **Falcon** (40B+), **BLOOM** (176B) (or their derivatives), and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
 
 ```python
 from transformers import AutoTokenizer
 from petals import AutoDistributedModelForCausalLM
 
 # Choose any model available at https://health.petals.dev
-model_name = "petals-team/StableBeluga2"
+model_name = "petals-team/StableBeluga2"  # This one is fine-tuned Llama 2 (70B)
 
 # Connect to a distributed network hosting model layers
 tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -31,9 +31,9 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
     🚀 &nbsp;<b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b>
 </p>
 
-🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
+🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
 
-🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
+🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
 
 💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
 
@@ -81,9 +81,8 @@ python3 -m petals.cli.run_server petals-team/StableBeluga2
 
 ## How does it work?
 
-- Petals runs large language models like [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
-- Single-batch inference runs at **up to 6 steps/sec** for **Llama 2** (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
-- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.
+- You load a small part of the model, then join a [network](https://health.petals.dev) of people serving the other parts. Single‑batch inference runs at up to **6 tokens/sec** for **Llama 2** (70B) and up to **4 tokens/sec** for **Falcon** (180B) — enough for [chatbots](https://chat.petals.dev) and interactive apps.
+- You can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of **PyTorch** and **🤗 Transformers**.
 
 <p align="center">
     <img src="https://i.imgur.com/RTYF3yW.png" width="800">
@@ -113,99 +112,15 @@ Advanced guides:
 - Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
 - Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
 
-## Benchmarks
-
-The benchmarks below are for BLOOM-176B:
-
-<table align="center">
-  <tr>
-    <th colspan="2">Network</th>
-    <th colspan="2">Single-batch inference<br>(steps/s)</th>
-    <th colspan="2">Parallel forward<br>(tokens/s)</th>
-  </tr>
-  <tr>
-    <th rowspan="2">Bandwidth</th>
-    <th rowspan="2">Round-trip<br>latency</th>
-    <th colspan="2">Sequence length</th>
-    <th colspan="2">Batch size</th>
-  </tr>
-  <tr align="center">
-    <td>128</td>
-    <td>2048</td>
-    <td>1</td>
-    <td>64</td>
-  </tr>
-  <tr>
-    <th colspan="6">Offloading, max. possible speed on 1x A100 <sup>1</sup></th>
-  </tr>
-  <tr align="center">
-    <td>256 Gbit/s</td>
-    <td></td>
-    <td>0.18</td>
-    <td>0.18</td>
-    <td>2.7</td>
-    <td>170.3</td>
-  </tr>
-  <tr align="center">
-    <td>128 Gbit/s</td>
-    <td></td>
-    <td>0.09</td>
-    <td>0.09</td>
-    <td>2.4</td>
-    <td>152.8</td>
-  </tr>
-  <tr>
-    <th colspan="6">Petals on 14 heterogeneous servers across Europe and North America <sup>2</sup></th>
-  </tr>
-  <tr align="center">
-    <td colspan="2">Real world</td>
-    <td>0.83</td>
-    <td>0.79</td>
-    <td>32.6</td>
-    <td>179.4</td>
-  </tr>
-  <tr>
-    <th colspan="6">Petals on 3 servers, with one A100 each <sup>3</sup></th>
-  </tr>
-  <tr align="center">
-    <td>1 Gbit/s</td>
-    <td>&lt; 5 ms</td>
-    <td>1.71</td>
-    <td>1.54</td>
-    <td>70.0</td>
-    <td>253.6</td>
-  </tr>
-  <tr align="center">
-    <td>100 Mbit/s</td>
-    <td>&lt; 5 ms</td>
-    <td>1.66</td>
-    <td>1.49</td>
-    <td>56.4</td>
-    <td>182.0</td>
-  </tr>
-  <tr align="center">
-    <td>100 Mbit/s</td>
-    <td>100 ms</td>
-    <td>1.23</td>
-    <td>1.11</td>
-    <td>19.7</td>
-    <td>112.2</td>
-  </tr>
-</table>
-
-<sup>1</sup> **An upper bound for offloading performance.** We base our offloading numbers on the best possible hardware setup for offloading: CPU RAM offloading via PCIe 4.0 with 16 PCIe lanes per GPU and PCIe switches for pairs of GPUs. We assume zero latency for the upper bound estimation. In 8-bit, the model uses 1 GB of memory per billion parameters. PCIe 4.0 with 16 lanes has a throughput of 256 Gbit/s, so offloading 176B parameters takes 5.5 seconds. The throughput is twice as slow (128 Gbit/s) if we have two GPUs behind the same PCIe switch.
-
-<sup>2</sup> **A real-world distributed setting** with 14 servers holding 2× RTX 3060, 4× 2080Ti, 2× 3090, 2× A4000, and 4× A5000 GPUs. These are personal servers and servers from university labs, spread across Europe and North America and connected to the Internet at speeds of 100–1000 Mbit/s. 4 servers operate from under firewalls.
-
-<sup>3</sup> **An optimistic setup** that requires least communication. The client nodes have 8 CPU cores and no GPU.
-
-We provide more evaluations and discuss these results in more detail in **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
-
-## 🛠️ Contributing
+### Benchmarks
+
+Please see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
+
+### 🛠️ Contributing
 
 Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing.
 
-## 📜 Citation
+### 📜 Citation
 
 Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel.
 [Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188)

+ 2 - 2
setup.cfg

@@ -37,7 +37,7 @@ install_requires =
     accelerate>=0.22.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.32.0,<5.0.0  # if you change this, please also change version assert in petals/__init__.py
+    transformers>=4.32.0,<4.35.0  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2
@@ -47,7 +47,7 @@ install_requires =
     cpufeature>=0.2.0; platform_machine == "x86_64"
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft>=0.5.0
+    peft==0.5.0
     safetensors>=0.3.1
     Dijkstar>=2.6.0
 

+ 3 - 3
src/petals/__init__.py

@@ -17,13 +17,13 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "2.1.0"
+__version__ = "2.3.0.dev1"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0"
+        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
     assert version.parse("1.1.10") <= version.parse(
         hivemind.__version__
     ), "Please install a proper hivemind version: pip install hivemind>=1.1.10"

+ 3 - 3
src/petals/cli/run_server.py

@@ -70,17 +70,17 @@ def main():
 
     parser.add_argument('--inference_max_length', type=int, default=None,
                         help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. '
-                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
     parser.add_argument('--min_batch_size', type=int, default=1,
                         help='Minimum required batch size for all operations (in total tokens)')
     parser.add_argument('--max_batch_size', type=int, default=None,
                         help='The total number of tokens in the same batch will not exceed this value. '
-                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
     parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
                         help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
     parser.add_argument('--attn_cache_tokens', type=int, default=None,
                         help='The number of past attention key/value pairs that will be stored between inference steps. '
-                             'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others')
 
     parser.add_argument('--cache_dir', type=str, default=None,
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')

+ 5 - 1
src/petals/client/config.py

@@ -1,10 +1,14 @@
 import dataclasses
+import os
 from typing import Optional, Sequence, Union
 
 from hivemind import PeerID
 
 from petals.constants import PUBLIC_INITIAL_PEERS
 
+_max_retries = os.getenv("PETALS_MAX_RETRIES")
+DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None
+
 
 @dataclasses.dataclass
 class ClientConfig:
@@ -21,7 +25,7 @@ class ClientConfig:
     request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
     update_period: float = 60  # refresh DHT information once in this many seconds
 
-    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    max_retries: Optional[int] = DEFAULT_MAX_RETRIES  # max number of retries before an exception (default: inf)
     min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
     max_backoff: float = 60  # limit maximal sleep time between retries to this value
     ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
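
A minimal sketch of how the new `PETALS_MAX_RETRIES` variable is picked up. Note that the module computes `DEFAULT_MAX_RETRIES` at import time, so the variable must be set before `petals.client.config` is imported; everything except the names taken from this diff is illustrative:

```python
import os

# Must be set before petals.client.config is imported:
# the module reads PETALS_MAX_RETRIES once to compute DEFAULT_MAX_RETRIES.
os.environ["PETALS_MAX_RETRIES"] = "10"

from petals.client.config import ClientConfig

config = ClientConfig()
assert config.max_retries == 10  # with the variable unset, this stays None (retry indefinitely)
```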

+ 1 - 9
src/petals/client/from_pretrained.py

@@ -6,7 +6,6 @@ import tempfile
 from contextvars import ContextVar
 from typing import List, Optional, Tuple, Union
 
-import torch
 from hivemind.utils.logging import get_logger
 from transformers import BloomPreTrainedModel, modeling_utils
 
@@ -22,21 +21,14 @@ class FromPretrainedMixin:
         model_name_or_path: Union[str, os.PathLike, None],
         *args,
         low_cpu_mem_usage: Optional[bool] = None,
-        torch_dtype: Optional[Union[str, torch.dtype]] = None,
         **kwargs,
     ):
         model_name_or_path = get_compatible_model_repo(model_name_or_path)
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
-        if torch_dtype is None:
-            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
-            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
-            torch_dtype = "auto"
 
         with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
-            return super().from_pretrained(
-                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
-            )
+            return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
 
     from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
         "low_cpu_mem_usage(`bool`, *optional*)",

+ 10 - 0
src/petals/client/inference_session.py

@@ -305,11 +305,21 @@ class InferenceSession:
         else:
             assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
             assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (inputs.shape[0], 1)
+            assert prompts.shape[2] <= inputs.shape[1]
+            assert prompts.shape[3] == inputs.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY_INT64
+        else:
+            assert len(hypo_ids) == len(inputs)
+            assert hypo_ids.dtype == torch.int64
 
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
         prompts = prompts.cpu()
+        hypo_ids = hypo_ids.cpu()
         step_id = str(uuid.uuid4())
 
         n_input_tokens = inputs.shape[1]

+ 5 - 7
src/petals/client/lm_head.py

@@ -1,8 +1,7 @@
 import dataclasses
 import platform
-from typing import Optional, Union
+from typing import Union
 
-import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -68,11 +67,10 @@ class LMHead(nn.Module):
         assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"
 
         if not self._bf16_warning_shown:
-            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
-                logger.warning(
-                    "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
-                    "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
-                )
+            logger.warning(
+                "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
+                "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
+            )
             self._bf16_warning_shown = True
 
         hidden_states = hidden_states.float()

+ 13 - 52
src/petals/client/routing/sequence_info.py

@@ -1,17 +1,15 @@
 import dataclasses
 import time
-from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
+from typing import Iterable, List, Optional, Tuple
 
 from hivemind import get_logger
 
 from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.utils.dht import compute_spans
 
 logger = get_logger(__name__)
 
 
-T = TypeVar("T")
-
-
 @dataclasses.dataclass
 class RemoteSequenceInfo:
     """
@@ -30,7 +28,7 @@ class RemoteSequenceInfo:
     last_updated_time: Optional[float]
 
     @classmethod
-    def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
+    def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo":
         block_uids = tuple(block_uids)
         empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
         empty_spans = tuple([] for _ in range(len(block_uids)))
@@ -39,7 +37,7 @@ class RemoteSequenceInfo:
     def __getitem__(self, ix: slice):
         assert isinstance(ix, slice)
         block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
-        spans_by_priority, spans_containing_block = self.compute_spans(block_infos)
+        spans_by_priority, spans_containing_block = self._sort_spans(block_infos)
         return RemoteSequenceInfo(
             block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
         )
@@ -47,60 +45,23 @@ class RemoteSequenceInfo:
     def __len__(self):
         return len(self.block_uids)
 
-    def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
+    def update_(self, new_block_infos: List[RemoteModuleInfo]):
         assert len(new_block_infos) == len(self.block_uids)
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
-            if info is None:
-                logger.debug(f"Found no block info for block {uid}")
-                continue
-            if not isinstance(info, RemoteModuleInfo):
-                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-                continue
-            if not info.servers:
-                logger.debug(f"Found no active peers for block {uid}")
-                continue
-            if info.uid != uid:
-                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-                continue
+            assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}"
             self.block_infos[block_index].servers = info.servers
 
-        self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+        self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos)
         self.last_updated_time = time.perf_counter()
 
     @staticmethod
-    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
-        closed_spans = []
-        active_spans = {}
-        for block_index, info in enumerate(block_infos):
-            if info is not None:
-                for peer_id, server_info in info.servers.items():
-                    if server_info.state != ServerState.ONLINE:
-                        continue
-                    if peer_id not in active_spans:
-                        active_spans[peer_id] = RemoteSpanInfo(
-                            peer_id=peer_id,
-                            start=block_index,
-                            end=block_index + 1,
-                            server_info=server_info,
-                        )
-                    else:  # peer_id in active_spans
-                        active_spans[peer_id].end = block_index + 1
-
-            for peer_id in list(active_spans.keys()):
-                if (
-                    info is None
-                    or peer_id not in info.servers
-                    or info.servers[peer_id].state != ServerState.ONLINE
-                    or block_index == len(block_infos) - 1
-                ):
-                    closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans, f"spans: {active_spans}"
-
-        closed_spans.sort(key=lambda span: span.length, reverse=True)
+    def _sort_spans(block_infos: List[RemoteModuleInfo]):
+        spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values())
+        spans_by_priority.sort(key=lambda span: span.length, reverse=True)
 
-        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
-        for span in closed_spans:
+        spans_containing_block = tuple([] for _ in range(len(block_infos)))
+        for span in spans_by_priority:
             for block_index in range(span.start, span.end):
                 spans_containing_block[block_index].append(span)
 
-        return closed_spans, spans_containing_block
+        return spans_by_priority, spans_containing_block

+ 0 - 4
src/petals/client/routing/sequence_manager.py

@@ -117,7 +117,6 @@ class RemoteSequenceManager:
         if state.sequence_info.last_updated_time is not None:
             assert block_uids == state.sequence_info.block_uids
             self._thread.ready.set()  # no need to await the first dht fetch
-            self._need_latest_infos = True
 
     @staticmethod
     def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
@@ -346,9 +345,6 @@ class RemoteSequenceManager:
         )
 
         for block_info in new_block_infos:
-            if not block_info:
-                continue
-
             # Apply allow and block lists
             block_info.servers = {
                 peer_id: server_info

+ 26 - 9
src/petals/data_structures.py

@@ -11,18 +11,15 @@ UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer
 CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
 
 
-class ServerState(Enum):
-    OFFLINE = 0
-    JOINING = 1
-    ONLINE = 2
-
-
-RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+def parse_uid(uid: ModuleUID) -> Tuple[str, int]:
+    assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs"
+    dht_prefix, index = uid.split(UID_DELIMITER)
+    return dht_prefix, int(index)
 
 
 @pydantic.dataclasses.dataclass
 class ModelInfo:
-    num_blocks: int
+    num_blocks: pydantic.conint(ge=1, strict=True)
     repository: Optional[str] = None
 
     def to_dict(self) -> dict:
@@ -33,11 +30,23 @@ class ModelInfo:
         return cls(**source)
 
 
+class ServerState(Enum):
+    OFFLINE = 0
+    JOINING = 1
+    ONLINE = 2
+
+
+RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState
     throughput: RPS
 
+    start_block: Optional[pydantic.conint(ge=0, strict=True)] = None
+    end_block: Optional[pydantic.conint(ge=0, strict=True)] = None
+
     public_name: Optional[str] = None
     version: Optional[str] = None
 
@@ -83,9 +92,17 @@ class RemoteSpanInfo:
     server_info: ServerInfo
 
     @property
-    def length(self):
+    def length(self) -> int:
         return self.end - self.start
 
+    @property
+    def state(self) -> ServerState:
+        return self.server_info.state
+
+    @property
+    def throughput(self) -> float:
+        return self.server_info.throughput
+
 
 RPCInfo = Dict[str, Any]
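
As a small sketch, the new `parse_uid()` helper splits a single (non-chained) module UID into its DHT prefix and block index; the UID below is made up for illustration:

```python
from petals.data_structures import UID_DELIMITER, parse_uid

# Real UIDs look like "<dht_prefix><UID_DELIMITER><block_index>"; "mymodel" is a stand-in prefix.
uid = f"mymodel{UID_DELIMITER}3"
assert parse_uid(uid) == ("mymodel", 3)
```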
 

+ 3 - 0
src/petals/models/falcon/config.py

@@ -31,6 +31,9 @@ class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig,
     def from_pretrained(
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
     ):
+        if "180B" in model_name_or_path.upper():
+            logger.info("Make sure you follow the Falcon-180B license: https://bit.ly/falcon-180b-license")
+
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
         if loading_from_repo and dht_prefix is None:
             dht_prefix = str(model_name_or_path)

+ 4 - 0
src/petals/models/falcon/model.py

@@ -47,6 +47,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[RemotePastKeyValues] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -68,6 +69,9 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert (
+            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
+        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
         assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
         assert use_cache is None or use_cache, f"{use_cache=} is not supported"
         assert not output_attentions, f"{output_attentions=} is not supported"

+ 209 - 10
src/petals/models/llama/block.py

@@ -3,13 +3,219 @@ LLaMA intermediate layer
 Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 See commit history for authorship.
 """
+import math
 from typing import Optional, Tuple
 
 import torch
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaConfig,
+    LlamaDecoderLayer,
+    LlamaMLP,
+    LlamaModel,
+    LlamaRMSNorm,
+    repeat_kv,
+    rotate_half,
+)
 
+from petals.utils.cuda_graphs import make_inference_graphed_callable
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class OptimizedLlamaAttention(LlamaAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._rotary_graph = None
+
+    def _optimized_apply_rotary(self, query_states, key_states, cos, sin):
+        if self._rotary_graph is None:
+            self._rotary_graph = make_inference_graphed_callable(
+                apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin)
+            )
+        return self._rotary_graph(query_states, key_states, cos, sin)
 
-class WrappedLlamaBlock(LlamaDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        assert not output_attentions
+        assert position_ids is None
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+            query_slices = self.q_proj.weight.split(
+                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+            )
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos = cos[:, :, kv_seq_len - q_len :]
+        sin = sin[:, :, kv_seq_len - q_len :]
+
+        if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        if self.config.pretraining_tp > 1:
+            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+        else:
+            attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: LlamaConfig):
+        nn.Module.__init__(self)
+        self.hidden_size = config.hidden_size
+        self.self_attn = OptimizedLlamaAttention(config=config)
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.pre_attn_graph = None
+        self.post_attn_graph = None
+
+    def _optimized_input_layernorm(self, hidden_states):
+        if self.pre_attn_graph is None:
+            self.pre_attn_graph = make_inference_graphed_callable(
+                self.input_layernorm.forward, sample_args=(hidden_states,)
+            )
+        return self.pre_attn_graph(hidden_states)
+
+    def _optimized_output_layernorm(self, hidden_states):
+        if self.post_attn_graph is None:
+            self.post_attn_graph = make_inference_graphed_callable(
+                self.post_attention_layernorm.forward, sample_args=(hidden_states,)
+            )
+        return self.post_attn_graph(hidden_states)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
+        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            hidden_states = self._optimized_input_layernorm(hidden_states)
+        else:
+            hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+
+        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            hidden_states = self._optimized_output_layernorm(hidden_states)
+        else:
+            hidden_states = self.post_attention_layernorm(hidden_states)
+
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -31,14 +237,7 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
             seq_length_with_past = seq_length_with_past + past_key_values_length
             past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
 
-        if position_ids is None:
-            device = hidden_states.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+        assert position_ids is None
 
         # embed positions
         if attention_mask is None:

+ 24 - 48
src/petals/server/block_selection.py

@@ -1,54 +1,23 @@
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List
 
 import numpy as np
 from hivemind import PeerID, get_logger
 
-from petals.data_structures import RemoteModuleInfo, ServerState
-
-__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
+from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.utils.dht import compute_spans
 
 logger = get_logger(__name__)
 
 
-@dataclass
-class Span:
-    start: int
-    end: int
-    throughput: float
-    state: ServerState
-
-    @property
-    def length(self):
-        return self.end - self.start
-
-    def move_to(self, new_start: int) -> None:
-        self.start, self.end = new_start, new_start + self.length
-
-
-def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
-    spans = {}
-    throughputs = np.zeros(len(module_infos))
-    for block, module in enumerate(module_infos):
-        if module is None:
-            continue
-
-        # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
-        # If the order were not defined, we would get slightly different values due to floating point errors,
-        # which may cause excess block replacements.
-        for peer_id, server in sorted(module.servers.items()):
-            if server.state == ServerState.OFFLINE:
-                continue
+def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray:
+    # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
+    # If the order were not defined, we would get slightly different values due to floating point errors,
+    # which may cause excess block replacements.
 
-            if peer_id in spans:
-                spans[peer_id].start = min(spans[peer_id].start, block)
-                spans[peer_id].end = max(spans[peer_id].start, block + 1)
-            else:
-                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)
-
-            throughputs[block] += server.throughput
-
-    return spans, throughputs
+    throughputs = np.zeros(total_blocks)
+    for span in sorted(spans.values(), key=lambda span: span.peer_id):
+        throughputs[span.start : span.end] += span.throughput
+    return throughputs
 
 
 def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
@@ -56,19 +25,26 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
     return min(options)[-1]
 
 
-def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
-    _, throughputs = compute_spans(module_infos)
+def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]:
+    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
+    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
+
     start = _choose_best_start(throughputs, num_blocks)
     return list(range(start, start + num_blocks))
 
 
+def _move_span(span: RemoteSpanInfo, new_start: int):
+    span.start, span.end = new_start, new_start + span.length
+
+
 def should_choose_other_blocks(
-    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
+    local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float
 ) -> bool:
     if balance_quality > 1.0:
         return True  # Forces rebalancing on each check (may be used for debugging purposes)
 
-    spans, throughputs = compute_spans(module_infos)
+    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
+    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
     initial_throughput = throughputs.min()
     eps = 1e-3
 
@@ -88,7 +64,7 @@ def should_choose_other_blocks(
         return False  # This server is on its best place already
 
     throughputs[local_span.start : local_span.end] += local_span.throughput * eps
-    local_span.move_to(new_start)
+    _move_span(local_span, new_start)
     throughputs[local_span.start : local_span.end] += local_span.throughput
 
     moved = True
@@ -105,7 +81,7 @@ def should_choose_other_blocks(
 
             throughputs[span.start : span.end] += span.throughput * eps
             if span.start != new_start:
-                span.move_to(new_start)
+                _move_span(span, new_start)
                 moved = True
             throughputs[span.start : span.end] += span.throughput
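
To illustrate the refactoring above, a hedged sketch of `compute_throughputs()`: it sums each span's throughput over the blocks that span covers. A plain string stands in for a real `hivemind.PeerID`, and the numbers are made up:

```python
import numpy as np

from petals.data_structures import RemoteSpanInfo, ServerInfo, ServerState
from petals.server.block_selection import compute_throughputs

# One peer serving blocks 0..2 at 2 requests/sec; "peerA" is a stand-in key.
span = RemoteSpanInfo(
    peer_id="peerA", start=0, end=3,
    server_info=ServerInfo(state=ServerState.ONLINE, throughput=2.0),
)
throughputs = compute_throughputs({"peerA": span}, total_blocks=5)
assert np.allclose(throughputs, [2.0, 2.0, 2.0, 0.0, 0.0])
```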
 

+ 16 - 13
src/petals/server/server.py

@@ -24,7 +24,7 @@ from transformers import PretrainedConfig
 
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -204,7 +204,7 @@ class Server:
 
         # For attention cache in GPU or RAM
         if attn_cache_tokens is None:
-            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
+            attn_cache_tokens = 16384 if is_multiquery_attn else 4096
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
         self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
@@ -221,11 +221,10 @@ class Server:
             num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         if block_indices is not None:
             try:
-                first_block_index, last_block_index = block_indices.split(":")
-                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
+                start_block, end_block = [int(index.strip()) for index in block_indices.split(":")]
             except Exception as e:
                 raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
-            block_indices = range(first_block_index, last_block_index)
+            block_indices = range(start_block, end_block)
             num_blocks = len(block_indices)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
 
@@ -704,11 +703,16 @@ class ModuleAnnouncerThread(threading.Thread):
         self.expiration = expiration
         self.trigger = threading.Event()
 
+        self.dht_prefix = parse_uid(module_uids[0])[0]
+        block_indices = [parse_uid(uid)[1] for uid in module_uids]
+        self.server_info.start_block = min(block_indices)
+        self.server_info.end_block = max(block_indices) + 1
+
         self.max_pinged = max_pinged
-        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
-        block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
-        start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [
+            f"{self.dht_prefix}{UID_DELIMITER}{i}"
+            for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1)
+        ]
         self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
@@ -756,12 +760,11 @@ class ModuleAnnouncerThread(threading.Thread):
 
     def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
         module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
-        middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
+        middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers}
         pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
         pinged_servers.discard(self.dht.peer_id)
-        if module_infos[-1] is not None:
-            # Sample servers hosting the block after the last one (most likely continuations) separately
-            pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
+        # Sample servers hosting the block after the last one (most likely continuations) separately
+        pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
         self.ping_aggregator.ping(list(pinged_servers))
 
 

+ 1 - 1
src/petals/server/throughput.py

@@ -56,7 +56,7 @@ def get_server_throughput(
 
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
-    with open(lock_path, "wb") as lock_fd:
+    with open(lock_path, "wb+") as lock_fd:
         logger.info("Loading throughput info")
         fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
         # The OS will release the lock when lock_fd is closed or the process is killed

+ 76 - 0
src/petals/utils/cuda_graphs.py

@@ -0,0 +1,76 @@
+import torch
+from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten
+
+
+def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3):
+    """Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass"""
+    assert not isinstance(callable, torch.nn.Module)
+    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
+        raise RuntimeError(
+            "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
+        )
+
+    flatten_arg, _ = _tree_flatten(sample_args)
+    flatten_sample_args = tuple(flatten_arg)
+    assert all(
+        isinstance(arg, torch.Tensor) for arg in flatten_arg
+    ), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed."
+
+    len_user_args = len(sample_args)
+    static_input_surface = flatten_sample_args
+
+    graph = torch.cuda.CUDAGraph()
+
+    # Warmup
+    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
+    # from ending up in any captures.
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        for _ in range(num_warmup_iters):
+            outputs, _ = _tree_flatten(callable(*sample_args))
+        del outputs
+    torch.cuda.current_stream().wait_stream(s)
+
+    # Capture forward graph
+    with torch.cuda.graph(graph):
+        outputs = callable(*sample_args)
+
+    flatten_outputs, output_unflatten_spec = _tree_flatten(outputs)
+    static_outputs = tuple(flatten_outputs)
+
+    def make_graphed_function(
+        graph,
+        len_user_args,
+        output_unflatten_spec,
+        static_input_surface,
+        static_outputs,
+    ):
+        def replay_graph(*inputs):
+            # At this stage, only the user args may (potentially) be new tensors.
+            for i in range(len_user_args):
+                if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
+                    static_input_surface[i].copy_(inputs[i])
+            graph.replay()
+            assert isinstance(static_outputs, tuple)
+            return tuple(o.detach() for o in static_outputs)
+
+        def functionalized(*user_args):
+            # Runs the autograd function with inputs == all inputs to the graph that might require grad
+            # (explicit user args + module parameters)
+            # Assumes module params didn't change since capture.
+            flatten_user_args, _ = _tree_flatten(user_args)
+            out = replay_graph(*flatten_user_args)
+            return _tree_unflatten(out, output_unflatten_spec)
+
+        return functionalized
+
+    # Put together the final graphed callable
+    graphed = make_graphed_function(
+        graph,
+        len_user_args,
+        output_unflatten_spec,
+        static_input_surface,
+        static_outputs,
+    )
+    return graphed
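
Usage-wise, the helper above expects a pure tensor-to-tensor function plus fixed-shape sample arguments; subsequent calls replay the captured graph. A minimal sketch (the function and shapes are made up, and a CUDA device is required):

```python
import torch

from petals.utils.cuda_graphs import make_inference_graphed_callable

def scaled_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + 2 * y

if torch.cuda.is_available():
    sample_args = (torch.randn(1, 8, device="cuda"), torch.randn(1, 8, device="cuda"))
    graphed = make_inference_graphed_callable(scaled_add, sample_args=sample_args)
    with torch.inference_mode():
        # Replays the captured CUDA graph; new inputs must keep the same shapes and dtypes.
        out = graphed(torch.ones(1, 8, device="cuda"), torch.zeros(1, 8, device="cuda"))
```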

+ 41 - 12
src/petals/utils/dht.py

@@ -11,7 +11,16 @@ from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
 
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
+from petals.data_structures import (
+    CHAIN_DELIMITER,
+    UID_DELIMITER,
+    ModuleUID,
+    RemoteModuleInfo,
+    RemoteSpanInfo,
+    ServerInfo,
+    ServerState,
+    parse_uid,
+)
 
 logger = get_logger(__name__)
 
@@ -70,7 +79,7 @@ def get_remote_module_infos(
     *,
     latest: bool = False,
     return_future: bool = False,
-) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
+) -> Union[List[RemoteModuleInfo], MPFuture]:
     return dht.run_coroutine(
         partial(
             _get_remote_module_infos,
@@ -90,7 +99,7 @@ async def _get_remote_module_infos(
     active_adapter: Optional[str],
     expiration_time: Optional[DHTExpiration],
     latest: bool,
-) -> List[Optional[RemoteModuleInfo]]:
+) -> List[RemoteModuleInfo]:
     if latest:
         assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
         expiration_time = math.inf
@@ -99,14 +108,14 @@ async def _get_remote_module_infos(
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
 
-    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
-    for i, uid in enumerate(uids):
-        metadata = found[uid]
+    modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids]
+    for module_info in modules:
+        metadata = found[module_info.uid]
         if metadata is None or not isinstance(metadata.value, dict):
             if metadata is not None:
-                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
+                logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}")
             continue
-        servers = {}
+
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
@@ -116,9 +125,29 @@ async def _get_remote_module_infos(
                     logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                     continue
 
-                servers[peer_id] = server_info
+                module_info.servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
-                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
-        if servers:
-            modules[i] = RemoteModuleInfo(uid, servers)
+                logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}")
     return modules
+
+
+def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]:
+    block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
+    num_blocks = len(module_infos)
+
+    spans = {}
+    for block_idx, module_info in enumerate(module_infos):
+        for peer_id, server_info in sorted(module_info.servers.items()):
+            if server_info.state.value < min_state.value:
+                continue
+
+            if peer_id not in spans or spans[peer_id].state.value < server_info.state.value:
+                spans[peer_id] = RemoteSpanInfo(
+                    peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info
+                )
+                if server_info.start_block is not None and server_info.end_block is not None:
+                    spans[peer_id].start = max(server_info.start_block - block_offset, 0)
+                    spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
+            elif spans[peer_id].state == server_info.state:
+                spans[peer_id].end = max(spans[peer_id].end, block_idx + 1)
+    return spans
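
A small sketch of how the new `compute_spans()` helper groups per-block server entries into contiguous spans. String peer IDs are used as stand-ins for real `hivemind.PeerID` objects, and the model UIDs are made up:

```python
from petals.data_structures import RemoteModuleInfo, ServerInfo, ServerState
from petals.utils.dht import compute_spans

online = ServerInfo(state=ServerState.ONLINE, throughput=1.0)
# "peerA" serves blocks 0..2 of a made-up 4-block model; block 3 has no servers yet.
infos = [
    RemoteModuleInfo(uid=f"mymodel.{i}", servers={"peerA": online} if i < 3 else {})
    for i in range(4)
]

spans = compute_spans(infos, min_state=ServerState.ONLINE)
assert (spans["peerA"].start, spans["peerA"].end, spans["peerA"].length) == (0, 3, 3)
```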

+ 1 - 1
src/petals/utils/disk_cache.py

@@ -22,7 +22,7 @@ def _blocks_lock(cache_dir: Optional[str], mode: int):
     lock_path = Path(cache_dir, BLOCKS_LOCK_FILE)
 
     os.makedirs(lock_path.parent, exist_ok=True)
-    with open(lock_path, "wb") as lock_fd:
+    with open(lock_path, "wb+") as lock_fd:
         fcntl.flock(lock_fd.fileno(), mode)
         # The OS will release the lock when lock_fd is closed or the process is killed
         yield

+ 93 - 5
tests/test_optimized_layers.py

@@ -3,6 +3,7 @@ from typing import Optional, Tuple
 import pytest
 import torch
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
@@ -94,10 +95,91 @@ class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
         return state
 
 
-@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models")
+class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        batch_size, seq_length, _ = hidden_states.shape
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        past_key_value = layer_past
+        if past_key_value is not None:
+            past_key_values_length = past_key_value[0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
+
+        if position_ids is None:
+            device = hidden_states.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+            )
+        attention_mask = LlamaModel._prepare_decoder_attention_mask(
+            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        )
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = self._reorder_cache_from_llama_to_bloom(
+                present_key_value, batch_size, seq_length_with_past
+            )
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom_to_llama(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        key_states = key_states.permute(0, 2, 1)
+        key_states = key_states.view(
+            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        value_states = value_states.view(*key_states.shape)
+        return (key_states, value_states)
+
+    def _reorder_cache_from_llama_to_bloom(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        value_states = value_states.view(
+            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        key_states = key_states.view(*value_states.shape)
+        key_states = key_states.permute(0, 2, 1)
+        return (key_states, value_states)
+
+
 @pytest.mark.parametrize("device", ["cpu", "cuda:0"])
 @pytest.mark.forked
-def test_falcon(device):
+def test_optimized_block(device):
     if device == "cuda:0" and not torch.cuda.is_available():
         pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
 
@@ -108,11 +190,17 @@ def test_falcon(device):
     quant_type = QuantType.NONE
 
     block = config.block_class(config).to(dtype)
-    block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+    block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+
+    if config.model_type == "falcon":
+        unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
+    elif config.model_type == "llama":
+        unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
+    else:
+        pytest.skip(f"This test is not applicable to {config.model_type} models")
 
-    unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
     unopt_block = convert_block(
-        unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
+        unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
     )
 
     unopt_block.load_state_dict(block.state_dict())