
Merge remote-tracking branch 'origin/main' into forward_kwargs

# Conflicts:
#	src/petals/__init__.py
#	src/petals/client/inference_session.py
Your Name, 1 year ago
parent commit 3195579620

+ 12 - 13
.github/workflows/run-tests.yaml

@@ -48,7 +48,6 @@ jobs:
          export MODEL_NAME="${{ matrix.model }}"
          export REF_NAME="${{ matrix.model }}"
          export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
-          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"

          # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)

@@ -61,27 +60,25 @@ jobs:

          until [ -s bootstrap.log ]; do sleep 5; done  # wait for DHT init

-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
-            --mean_balance_check_period 10 \
-            --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
+          export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \
+            --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS"
+          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
+
+          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log &
          SERVER1_PID=$!
          # ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there

          sleep 10  # wait for the 1st server to choose blocks

-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
-            --identity_path tests/server2.id \
-            --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
+          $RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log &
          SERVER2_PID=$!

-          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
-            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
-            --initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
+          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \
+            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log &
          SERVER3_PID=$!
          # ^-- chunking test

-          python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
-            --initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
+          $RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log &
          SERVER4_PID=$!
          # ^-- tensor parallelism test (not compatible with adapters yet)

@@ -102,6 +99,9 @@ jobs:
          export no_proxy=*
          export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

+          # Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely
+          export PETALS_MAX_RETRIES=10
+
          pytest tests --durations=0 --durations-min=1.0 -v

          # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
@@ -118,4 +118,3 @@ jobs:
          # [Step 4] Clean up

          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
-          echo "Done!"

+ 12 - 97
README.md

@@ -8,14 +8,14 @@
    <br>
</p>

-Generate text with distributed **Llama 2 (70B)**, **Stable Beluga 2**, **Guanaco-65B** or **BLOOM-176B** and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
+Generate text with distributed **Llama 2** (70B), **Falcon** (40B+), **BLOOM** (176B) (or their derivatives), and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:

```python
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

# Choose any model available at https://health.petals.dev
-model_name = "petals-team/StableBeluga2"
+model_name = "petals-team/StableBeluga2"  # This one is fine-tuned Llama 2 (70B)

# Connect to a distributed network hosting model layers
tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -31,9 +31,9 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
    🚀 &nbsp;<b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b>
</p>

-🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
+🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.

-🔏 **Privacy.** Your data will be processed by other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
+🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).

💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!

@@ -81,9 +81,8 @@ python3 -m petals.cli.run_server petals-team/StableBeluga2
## How does it work?

-- Petals runs large language models like [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) and [BLOOM](https://huggingface.co/bigscience/bloom) **collaboratively** — you load a small part of the model, then join people serving the other parts to run inference or fine-tuning.
-- Single-batch inference runs at **up to 6 steps/sec** for **Llama 2** (70B) and &approx; 1 step/sec for BLOOM-176B. This is [up to 10x faster](https://github.com/bigscience-workshop/petals#benchmarks) than offloading, enough to build [chatbots](https://chat.petals.dev) and other interactive apps. Parallel inference reaches hundreds of tokens/sec.
-- Beyond classic language model APIs — you can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of PyTorch.
+- You load a small part of the model, then join a [network](https://health.petals.dev) of people serving the other parts. Single‑batch inference runs at up to **6 tokens/sec** for **Llama 2** (70B) and up to **4 tokens/sec** for **Falcon** (180B) — enough for [chatbots](https://chat.petals.dev) and interactive apps.
+- You can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of **PyTorch** and **🤗 Transformers**.

<p align="center">
    <img src="https://i.imgur.com/RTYF3yW.png" width="800">
@@ -113,99 +112,15 @@ Advanced guides:
- Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
- Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)

-## Benchmarks
-
-The benchmarks below are for BLOOM-176B:
-
-<table align="center">
-  <tr>
-    <th colspan="2">Network</th>
-    <th colspan="2">Single-batch inference<br>(steps/s)</th>
-    <th colspan="2">Parallel forward<br>(tokens/s)</th>
-  </tr>
-  <tr>
-    <th rowspan="2">Bandwidth</th>
-    <th rowspan="2">Round-trip<br>latency</th>
-    <th colspan="2">Sequence length</th>
-    <th colspan="2">Batch size</th>
-  </tr>
-  <tr align="center">
-    <td>128</td>
-    <td>2048</td>
-    <td>1</td>
-    <td>64</td>
-  </tr>
-  <tr>
-    <th colspan="6">Offloading, max. possible speed on 1x A100 <sup>1</sup></th>
-  </tr>
-  <tr align="center">
-    <td>256 Gbit/s</td>
-    <td></td>
-    <td>0.18</td>
-    <td>0.18</td>
-    <td>2.7</td>
-    <td>170.3</td>
-  </tr>
-  <tr align="center">
-    <td>128 Gbit/s</td>
-    <td></td>
-    <td>0.09</td>
-    <td>0.09</td>
-    <td>2.4</td>
-    <td>152.8</td>
-  </tr>
-  <tr>
-    <th colspan="6">Petals on 14 heterogeneous servers across Europe and North America <sup>2</sup></th>
-  </tr>
-  <tr align="center">
-    <td colspan="2">Real world</td>
-    <td>0.83</td>
-    <td>0.79</td>
-    <td>32.6</td>
-    <td>179.4</td>
-  </tr>
-  <tr>
-    <th colspan="6">Petals on 3 servers, with one A100 each <sup>3</sup></th>
-  </tr>
-  <tr align="center">
-    <td>1 Gbit/s</td>
-    <td>&lt; 5 ms</td>
-    <td>1.71</td>
-    <td>1.54</td>
-    <td>70.0</td>
-    <td>253.6</td>
-  </tr>
-  <tr align="center">
-    <td>100 Mbit/s</td>
-    <td>&lt; 5 ms</td>
-    <td>1.66</td>
-    <td>1.49</td>
-    <td>56.4</td>
-    <td>182.0</td>
-  </tr>
-  <tr align="center">
-    <td>100 Mbit/s</td>
-    <td>100 ms</td>
-    <td>1.23</td>
-    <td>1.11</td>
-    <td>19.7</td>
-    <td>112.2</td>
-  </tr>
-</table>
-
-<sup>1</sup> **An upper bound for offloading performance.** We base our offloading numbers on the best possible hardware setup for offloading: CPU RAM offloading via PCIe 4.0 with 16 PCIe lanes per GPU and PCIe switches for pairs of GPUs. We assume zero latency for the upper bound estimation. In 8-bit, the model uses 1 GB of memory per billion parameters. PCIe 4.0 with 16 lanes has a throughput of 256 Gbit/s, so offloading 176B parameters takes 5.5 seconds. The throughput is twice as slow (128 Gbit/s) if we have two GPUs behind the same PCIe switch.
-
-<sup>2</sup> **A real-world distributed setting** with 14 servers holding 2× RTX 3060, 4× 2080Ti, 2× 3090, 2× A4000, and 4× A5000 GPUs. These are personal servers and servers from university labs, spread across Europe and North America and connected to the Internet at speeds of 100–1000 Mbit/s. 4 servers operate from under firewalls.
-
-<sup>3</sup> **An optimistic setup** that requires least communication. The client nodes have 8 CPU cores and no GPU.
-
-We provide more evaluations and discuss these results in more detail in **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
-
-## 🛠️ Contributing
+### Benchmarks
+
+Please see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
+
+### 🛠️ Contributing

Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing.

-## 📜 Citation
+### 📜 Citation

Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel.
[Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188)

+ 2 - 2
setup.cfg

@@ -37,7 +37,7 @@ install_requires =
    accelerate>=0.22.0
    huggingface-hub>=0.11.1,<1.0.0
    tokenizers>=0.13.3
-    transformers>=4.32.0,<5.0.0  # if you change this, please also change version assert in petals/__init__.py
+    transformers>=4.32.0,<4.35.0  # if you change this, please also change version assert in petals/__init__.py
    speedtest-cli==2.1.3
    pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
    hivemind==1.1.10.post2
@@ -47,7 +47,7 @@ install_requires =
    cpufeature>=0.2.0; platform_machine == "x86_64"
    packaging>=20.9
    sentencepiece>=0.1.99
-    peft>=0.5.0
+    peft==0.5.0
    safetensors>=0.3.1
    Dijkstar>=2.6.0


+ 3 - 3
src/petals/__init__.py

@@ -17,13 +17,13 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs

-__version__ = "2.1.0"
+__version__ = "2.3.0.dev1"


if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    assert (
-        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0"
+        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
    assert version.parse("1.1.10") <= version.parse(
        hivemind.__version__
    ), "Please install a proper hivemind version: pip install hivemind>=1.1.10"

+ 3 - 3
src/petals/cli/run_server.py

@@ -70,17 +70,17 @@ def main():

    parser.add_argument('--inference_max_length', type=int, default=None,
                        help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. '
-                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
    parser.add_argument('--min_batch_size', type=int, default=1,
                        help='Minimum required batch size for all operations (in total tokens)')
    parser.add_argument('--max_batch_size', type=int, default=None,
                        help='The total number of tokens in the same batch will not exceed this value. '
-                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
    parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
                        help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
    parser.add_argument('--attn_cache_tokens', type=int, default=None,
                        help='The number of past attention key/value pairs that will be stored between inference steps. '
-                             'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others')

    parser.add_argument('--cache_dir', type=str, default=None,
                        help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')

+ 5 - 1
src/petals/client/config.py

@@ -1,10 +1,14 @@
import dataclasses
+import os
from typing import Optional, Sequence, Union

from hivemind import PeerID

from petals.constants import PUBLIC_INITIAL_PEERS

+_max_retries = os.getenv("PETALS_MAX_RETRIES")
+DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None
+

@dataclasses.dataclass
class ClientConfig:
@@ -21,7 +25,7 @@ class ClientConfig:
    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
    update_period: float = 60  # refresh DHT information once in this many seconds

-    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    max_retries: Optional[int] = DEFAULT_MAX_RETRIES  # max number of retries before an exception (default: inf)
    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
    max_backoff: float = 60  # limit maximal sleep time between retries to this value
    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds

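The two hunks above introduce the `PETALS_MAX_RETRIES` environment variable and make it the default for `ClientConfig.max_retries`. A minimal usage sketch (the retry value is illustrative; the variable must be set before `petals` is imported, since the default is read at import time):

```python
import os

# Illustrative value: cap retries so failures raise instead of retrying forever.
os.environ["PETALS_MAX_RETRIES"] = "10"

from petals.client.config import ClientConfig  # noqa: E402  (import after setting the env var)

config = ClientConfig()
assert config.max_retries == 10  # with the variable unset, this stays None (retry indefinitely)
```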
+ 1 - 9
src/petals/client/from_pretrained.py

@@ -6,7 +6,6 @@ import tempfile
from contextvars import ContextVar
from typing import List, Optional, Tuple, Union

-import torch
from hivemind.utils.logging import get_logger
from transformers import BloomPreTrainedModel, modeling_utils

@@ -22,21 +21,14 @@ class FromPretrainedMixin:
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
-        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        **kwargs,
    ):
        model_name_or_path = get_compatible_model_repo(model_name_or_path)
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
-        if torch_dtype is None:
-            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
-            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
-            torch_dtype = "auto"

        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
-            return super().from_pretrained(
-                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
-            )
+            return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)

    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
        "low_cpu_mem_usage(`bool`, *optional*)",

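With the `torch_dtype="auto"` fallback removed above, `from_pretrained()` now follows the plain `transformers` default (the deleted comment notes that `torch_dtype=None` means `torch.float32`). A hedged sketch of keeping the old behaviour by passing the dtype explicitly, with the model name taken from the README example:

```python
from petals import AutoDistributedModelForCausalLM

# Assumption for illustration: the caller now controls the dtype explicitly,
# e.g. "auto" to reuse the checkpoint/config dtype as the removed default did.
model = AutoDistributedModelForCausalLM.from_pretrained(
    "petals-team/StableBeluga2",
    torch_dtype="auto",
)
```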
+ 10 - 0
src/petals/client/inference_session.py

@@ -305,11 +305,21 @@ class InferenceSession:
        else:
            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
            assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (inputs.shape[0], 1)
+            assert prompts.shape[2] <= inputs.shape[1]
+            assert prompts.shape[3] == inputs.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY_INT64
+        else:
+            assert len(hypo_ids) == len(inputs)
+            assert hypo_ids.dtype == torch.int64

        inputs_device = inputs.device
        inputs_dtype = inputs.dtype
        inputs = inputs.cpu()
        prompts = prompts.cpu()
+        hypo_ids = hypo_ids.cpu()
        step_id = str(uuid.uuid4())

        n_input_tokens = inputs.shape[1]

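The added assertions encode the expected shapes: `inputs` is `[batch_size, seq_len, hid_size]`, deep `prompts` are `[num_blocks, batch_size or 1, prefix_len, hid_size]`, and `hypo_ids` is an int64 vector with one entry per input row. A shape-only sketch with made-up sizes:

```python
import torch

batch_size, seq_len, hid_size, num_blocks = 2, 5, 16, 4  # illustrative sizes only

inputs = torch.randn(batch_size, seq_len, hid_size)
prompts = torch.randn(num_blocks, batch_size, 3, hid_size)  # prefix_len = 3 <= seq_len
hypo_ids = torch.arange(batch_size, dtype=torch.int64)      # one id per sequence in the batch

# Mirrors the checks added above
assert prompts.shape[0] == num_blocks
assert prompts.shape[1] in (inputs.shape[0], 1)
assert prompts.shape[2] <= inputs.shape[1]
assert prompts.shape[3] == inputs.shape[2]
assert len(hypo_ids) == len(inputs) and hypo_ids.dtype == torch.int64
```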
+ 5 - 7
src/petals/client/lm_head.py

@@ -1,8 +1,7 @@
import dataclasses
import platform
-from typing import Optional, Union
+from typing import Union

-import psutil
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -68,11 +67,10 @@ class LMHead(nn.Module):
        assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"

        if not self._bf16_warning_shown:
-            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
-                logger.warning(
-                    "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
-                    "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
-                )
+            logger.warning(
+                "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
+                "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
+            )
            self._bf16_warning_shown = True

        hidden_states = hidden_states.float()

+ 13 - 52
src/petals/client/routing/sequence_info.py

@@ -1,17 +1,15 @@
import dataclasses
import time
-from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
+from typing import Iterable, List, Optional, Tuple

from hivemind import get_logger

from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.utils.dht import compute_spans

logger = get_logger(__name__)


-T = TypeVar("T")
-
-
@dataclasses.dataclass
class RemoteSequenceInfo:
    """
@@ -30,7 +28,7 @@ class RemoteSequenceInfo:
    last_updated_time: Optional[float]

    @classmethod
-    def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
+    def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo":
        block_uids = tuple(block_uids)
        empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
        empty_spans = tuple([] for _ in range(len(block_uids)))
@@ -39,7 +37,7 @@ class RemoteSequenceInfo:
    def __getitem__(self, ix: slice):
        assert isinstance(ix, slice)
        block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
-        spans_by_priority, spans_containing_block = self.compute_spans(block_infos)
+        spans_by_priority, spans_containing_block = self._sort_spans(block_infos)
        return RemoteSequenceInfo(
            block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
        )
@@ -47,60 +45,23 @@ class RemoteSequenceInfo:
    def __len__(self):
        return len(self.block_uids)

-    def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
+    def update_(self, new_block_infos: List[RemoteModuleInfo]):
        assert len(new_block_infos) == len(self.block_uids)
        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
-            if info is None:
-                logger.debug(f"Found no block info for block {uid}")
-                continue
-            if not isinstance(info, RemoteModuleInfo):
-                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-                continue
-            if not info.servers:
-                logger.debug(f"Found no active peers for block {uid}")
-                continue
-            if info.uid != uid:
-                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-                continue
+            assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}"
            self.block_infos[block_index].servers = info.servers

-        self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+        self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos)
        self.last_updated_time = time.perf_counter()

    @staticmethod
-    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
-        closed_spans = []
-        active_spans = {}
-        for block_index, info in enumerate(block_infos):
-            if info is not None:
-                for peer_id, server_info in info.servers.items():
-                    if server_info.state != ServerState.ONLINE:
-                        continue
-                    if peer_id not in active_spans:
-                        active_spans[peer_id] = RemoteSpanInfo(
-                            peer_id=peer_id,
-                            start=block_index,
-                            end=block_index + 1,
-                            server_info=server_info,
-                        )
-                    else:  # peer_id in active_spans
-                        active_spans[peer_id].end = block_index + 1
-
-            for peer_id in list(active_spans.keys()):
-                if (
-                    info is None
-                    or peer_id not in info.servers
-                    or info.servers[peer_id].state != ServerState.ONLINE
-                    or block_index == len(block_infos) - 1
-                ):
-                    closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans, f"spans: {active_spans}"
-
-        closed_spans.sort(key=lambda span: span.length, reverse=True)
+    def _sort_spans(block_infos: List[RemoteModuleInfo]):
+        spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values())
+        spans_by_priority.sort(key=lambda span: span.length, reverse=True)

-        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
-        for span in closed_spans:
+        spans_containing_block = tuple([] for _ in range(len(block_infos)))
+        for span in spans_by_priority:
            for block_index in range(span.start, span.end):
                spans_containing_block[block_index].append(span)

-        return closed_spans, spans_containing_block
+        return spans_by_priority, spans_containing_block

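The refactored `_sort_spans()` delegates span discovery to `petals.utils.dht.compute_spans()` and keeps only the sorting and bucketing step. A standalone sketch of that step with stand-in spans (plain `(start, end)` tuples instead of `RemoteSpanInfo`):

```python
# Illustrative spans over a 6-block sequence
spans = [(0, 4), (2, 6), (5, 6)]
num_blocks = 6

# Longest spans first, as in spans_by_priority
spans_by_priority = sorted(spans, key=lambda span: span[1] - span[0], reverse=True)

# For each block, collect every span covering it, as in spans_containing_block
spans_containing_block = tuple([] for _ in range(num_blocks))
for start, end in spans_by_priority:
    for block_index in range(start, end):
        spans_containing_block[block_index].append((start, end))

print(spans_containing_block[3])  # [(0, 4), (2, 6)]: both spans cover block 3
```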
+ 0 - 4
src/petals/client/routing/sequence_manager.py

@@ -117,7 +117,6 @@ class RemoteSequenceManager:
        if state.sequence_info.last_updated_time is not None:
            assert block_uids == state.sequence_info.block_uids
            self._thread.ready.set()  # no need to await the first dht fetch
-            self._need_latest_infos = True

    @staticmethod
    def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
@@ -346,9 +345,6 @@ class RemoteSequenceManager:
        )

        for block_info in new_block_infos:
-            if not block_info:
-                continue
-
            # Apply allow and block lists
            block_info.servers = {
                peer_id: server_info

+ 26 - 9
src/petals/data_structures.py

@@ -11,18 +11,15 @@ UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer
CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"


-class ServerState(Enum):
-    OFFLINE = 0
-    JOINING = 1
-    ONLINE = 2
-
-
-RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+def parse_uid(uid: ModuleUID) -> Tuple[str, int]:
+    assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs"
+    dht_prefix, index = uid.split(UID_DELIMITER)
+    return dht_prefix, int(index)


@pydantic.dataclasses.dataclass
class ModelInfo:
-    num_blocks: int
+    num_blocks: pydantic.conint(ge=1, strict=True)
    repository: Optional[str] = None

    def to_dict(self) -> dict:
@@ -33,11 +30,23 @@ class ModelInfo:
        return cls(**source)


+class ServerState(Enum):
+    OFFLINE = 0
+    JOINING = 1
+    ONLINE = 2
+
+
+RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+
+
@pydantic.dataclasses.dataclass
class ServerInfo:
    state: ServerState
    throughput: RPS

+    start_block: Optional[pydantic.conint(ge=0, strict=True)] = None
+    end_block: Optional[pydantic.conint(ge=0, strict=True)] = None
+
    public_name: Optional[str] = None
    version: Optional[str] = None

@@ -83,9 +92,17 @@ class RemoteSpanInfo:
    server_info: ServerInfo

    @property
-    def length(self):
+    def length(self) -> int:
        return self.end - self.start

+    @property
+    def state(self) -> ServerState:
+        return self.server_info.state
+
+    @property
+    def throughput(self) -> float:
+        return self.server_info.throughput
+

RPCInfo = Dict[str, Any]


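The new `parse_uid()` helper splits a module UID into its DHT prefix and block index on `UID_DELIMITER`. A quick usage sketch (the UID string is only an example):

```python
from petals.data_structures import parse_uid

dht_prefix, block_index = parse_uid("StableBeluga2-hf.7")  # example UID: "<prefix>.<block index>"
assert dht_prefix == "StableBeluga2-hf"
assert block_index == 7
```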
+ 3 - 0
src/petals/models/falcon/config.py

@@ -31,6 +31,9 @@ class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig,
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
+        if "180B" in model_name_or_path.upper():
+            logger.info("Make sure you follow the Falcon-180B license: https://bit.ly/falcon-180b-license")
+
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)

+ 4 - 0
src/petals/models/falcon/model.py

@@ -47,6 +47,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[RemotePastKeyValues] = None,
        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
@@ -68,6 +69,9 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
        assert (
            attention_mask is None or (attention_mask == 1).all()
        ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert (
+            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
+        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
        assert not output_attentions, f"{output_attentions=} is not supported"

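The new `position_ids` assertion only admits consecutive positions, as produced by regular autoregressive decoding. A tiny illustration of what passes the added check:

```python
import torch

def is_consecutive(position_ids: torch.Tensor) -> bool:
    # Same expression as the assertion added above
    return bool((position_ids[:, 1:] - position_ids[:, :-1] == 1).all())

assert is_consecutive(torch.tensor([[3, 4, 5, 6]]))      # consecutive: accepted
assert not is_consecutive(torch.tensor([[0, 2, 3, 4]]))  # gap between 0 and 2: rejected
```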
+ 209 - 10
src/petals/models/llama/block.py

@@ -3,13 +3,219 @@ LLaMA intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
See commit history for authorship.
"""
+import math
from typing import Optional, Tuple

import torch
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaConfig,
+    LlamaDecoderLayer,
+    LlamaMLP,
+    LlamaModel,
+    LlamaRMSNorm,
+    repeat_kv,
+    rotate_half,
+)
+from petals.utils.cuda_graphs import make_inference_graphed_callable
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class OptimizedLlamaAttention(LlamaAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._rotary_graph = None
+
+    def _optimized_apply_rotary(self, query_states, key_states, cos, sin):
+        if self._rotary_graph is None:
+            self._rotary_graph = make_inference_graphed_callable(
+                apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin)
+            )
+        return self._rotary_graph(query_states, key_states, cos, sin)
-class WrappedLlamaBlock(LlamaDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        assert not output_attentions
+        assert position_ids is None
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+            query_slices = self.q_proj.weight.split(
+                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+            )
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos = cos[:, :, kv_seq_len - q_len :]
+        sin = sin[:, :, kv_seq_len - q_len :]
+
+        if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        if self.config.pretraining_tp > 1:
+            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+        else:
+            attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: LlamaConfig):
+        nn.Module.__init__(self)
+        self.hidden_size = config.hidden_size
+        self.self_attn = OptimizedLlamaAttention(config=config)
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.pre_attn_graph = None
+        self.post_attn_graph = None
+
+    def _optimized_input_layernorm(self, hidden_states):
+        if self.pre_attn_graph is None:
+            self.pre_attn_graph = make_inference_graphed_callable(
+                self.input_layernorm.forward, sample_args=(hidden_states,)
+            )
+        return self.pre_attn_graph(hidden_states)
+
+    def _optimized_output_layernorm(self, hidden_states):
+        if self.post_attn_graph is None:
+            self.post_attn_graph = make_inference_graphed_callable(
+                self.post_attention_layernorm.forward, sample_args=(hidden_states,)
+            )
+        return self.post_attn_graph(hidden_states)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
+        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            hidden_states = self._optimized_input_layernorm(hidden_states)
+        else:
+            hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+
+        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            hidden_states = self._optimized_output_layernorm(hidden_states)
+        else:
+            hidden_states = self.post_attention_layernorm(hidden_states)
+
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -31,14 +237,7 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
            seq_length_with_past = seq_length_with_past + past_key_values_length
            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

-        if position_ids is None:
-            device = hidden_states.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+        assert position_ids is None

        # embed positions
        if attention_mask is None:

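`OptimizedLlamaAttention` and `OptimizedLlamaDecoderLayer` capture the rotary embedding and the two RMSNorm calls into CUDA graphs via `make_inference_graphed_callable`, but only on the single-token decoding path (`q_len == 1`, inference mode, CUDA device); longer prefills fall back to the eager functions. The standalone `apply_rotary_pos_emb` helper added above can be exercised on toy tensors; the shapes below are illustrative:

```python
import torch
from transformers.models.llama.modeling_llama import rotate_half

def apply_rotary_pos_emb(q, k, cos, sin):
    # Same formula as the helper defined in this block
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

bsz, heads, q_len, head_dim = 1, 2, 1, 8  # toy single-token decoding shapes
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, q_len, head_dim)
cos = torch.randn(1, 1, q_len, head_dim)
sin = torch.randn(1, 1, q_len, head_dim)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```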
+ 24 - 48
src/petals/server/block_selection.py

@@ -1,54 +1,23 @@
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List

import numpy as np
from hivemind import PeerID, get_logger

-from petals.data_structures import RemoteModuleInfo, ServerState
-
-__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
+from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.utils.dht import compute_spans

logger = get_logger(__name__)


-@dataclass
-class Span:
-    start: int
-    end: int
-    throughput: float
-    state: ServerState
-
-    @property
-    def length(self):
-        return self.end - self.start
-
-    def move_to(self, new_start: int) -> None:
-        self.start, self.end = new_start, new_start + self.length
-
-
-def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
-    spans = {}
-    throughputs = np.zeros(len(module_infos))
-    for block, module in enumerate(module_infos):
-        if module is None:
-            continue
-
-        # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
-        # If the order were not defined, we would get slightly different values due to floating point errors,
-        # which may cause excess block replacements.
-        for peer_id, server in sorted(module.servers.items()):
-            if server.state == ServerState.OFFLINE:
-                continue
+def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray:
+    # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
+    # If the order were not defined, we would get slightly different values due to floating point errors,
+    # which may cause excess block replacements.
-            if peer_id in spans:
-                spans[peer_id].start = min(spans[peer_id].start, block)
-                spans[peer_id].end = max(spans[peer_id].start, block + 1)
-            else:
-                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)
-
-            throughputs[block] += server.throughput
-
-    return spans, throughputs
+    throughputs = np.zeros(total_blocks)
+    for span in sorted(spans.values(), key=lambda span: span.peer_id):
+        throughputs[span.start : span.end] += span.throughput
+    return throughputs


def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
@@ -56,19 +25,26 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
    return min(options)[-1]


-def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
-    _, throughputs = compute_spans(module_infos)
+def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]:
+    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
+    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
+
    start = _choose_best_start(throughputs, num_blocks)
    return list(range(start, start + num_blocks))


+def _move_span(span: RemoteSpanInfo, new_start: int):
+    span.start, span.end = new_start, new_start + span.length
+
+
def should_choose_other_blocks(
-    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
+    local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float
) -> bool:
    if balance_quality > 1.0:
        return True  # Forces rebalancing on each check (may be used for debugging purposes)

-    spans, throughputs = compute_spans(module_infos)
+    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
+    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
    initial_throughput = throughputs.min()
    eps = 1e-3

@@ -88,7 +64,7 @@ def should_choose_other_blocks(
        return False  # This server is on its best place already

    throughputs[local_span.start : local_span.end] += local_span.throughput * eps
-    local_span.move_to(new_start)
+    _move_span(local_span, new_start)
    throughputs[local_span.start : local_span.end] += local_span.throughput

    moved = True
@@ -105,7 +81,7 @@ def should_choose_other_blocks(

            throughputs[span.start : span.end] += span.throughput * eps
            if span.start != new_start:
-                span.move_to(new_start)
+                _move_span(span, new_start)
                moved = True
            throughputs[span.start : span.end] += span.throughput


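`compute_throughputs()` now takes the spans returned by `petals.utils.dht.compute_spans()` and accumulates each server's throughput over the blocks it covers. A self-contained sketch with stand-in spans (simple tuples instead of `RemoteSpanInfo`):

```python
import numpy as np

# Illustrative spans only: (peer_id, start, end, throughput)
spans = [("peer_a", 0, 4, 100.0), ("peer_b", 2, 6, 50.0)]
total_blocks = 6

throughputs = np.zeros(total_blocks)
for _, start, end, throughput in sorted(spans):  # deterministic order, as in the real code
    throughputs[start:end] += throughput

print(throughputs)  # [100. 100. 150. 150.  50.  50.]
```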
+ 16 - 13
src/petals/server/server.py

@@ -24,7 +24,7 @@ from transformers import PretrainedConfig

import petals
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid
from petals.server import block_selection
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -204,7 +204,7 @@ class Server:
 
 
         # For attention cache in GPU or RAM
         if attn_cache_tokens is None:
-            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
+            attn_cache_tokens = 16384 if is_multiquery_attn else 4096
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
         self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
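
To make the halved default concrete, here is a back-of-the-envelope version of the three cache-size lines above; the hidden size, KV-group count, and dtype are illustrative values, not taken from a specific model:

    # Attention-cache size per block for attn_cache_tokens = 4096 (new non-MQA default),
    # assuming a hypothetical hidden_size = 4096, num_key_value_groups = 1, float16 (2 bytes).
    attn_cache_tokens = 4096
    hidden_size = 4096
    num_key_value_groups = 1
    bytes_per_value = 2  # float16

    cache_values_per_block = 2 * hidden_size * attn_cache_tokens      # keys + values
    cache_values_per_block //= num_key_value_groups
    cache_bytes_per_block = cache_values_per_block * bytes_per_value
    print(cache_bytes_per_block / 2**20)  # 64.0 MiB per block
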
@@ -221,11 +221,10 @@ class Server:
             num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         if block_indices is not None:
             try:
-                first_block_index, last_block_index = block_indices.split(":")
-                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
+                start_block, end_block = [int(index.strip()) for index in block_indices.split(":")]
             except Exception as e:
                 raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
-            block_indices = range(first_block_index, last_block_index)
+            block_indices = range(start_block, end_block)
             num_blocks = len(block_indices)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
 
 
@@ -704,11 +703,16 @@ class ModuleAnnouncerThread(threading.Thread):
         self.expiration = expiration
         self.trigger = threading.Event()
 
 
+        self.dht_prefix = parse_uid(module_uids[0])[0]
+        block_indices = [parse_uid(uid)[1] for uid in module_uids]
+        self.server_info.start_block = min(block_indices)
+        self.server_info.end_block = max(block_indices) + 1
+
         self.max_pinged = max_pinged
-        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
-        block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
-        start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [
+            f"{self.dht_prefix}{UID_DELIMITER}{i}"
+            for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1)
+        ]
         self.ping_aggregator = PingAggregator(self.dht)
 
 
     def run(self) -> None:
@@ -756,12 +760,11 @@ class ModuleAnnouncerThread(threading.Thread):
 
 
     def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
         module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
-        middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
+        middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers}
         pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
         pinged_servers.discard(self.dht.peer_id)
-        if module_infos[-1] is not None:
-            # Sample servers hosting the block after the last one (most likely continuations) separately
-            pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
+        # Sample servers hosting the block after the last one (most likely continuations) separately
+        pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
         self.ping_aggregator.ping(list(pinged_servers))
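
For context, `parse_uid` (imported at the top of this file) is assumed to split a module UID of the form `<dht_prefix><UID_DELIMITER><block_index>` into a (prefix, index) pair, mirroring the string handling it replaces. A rough sketch under that assumption; the delimiter value and the example UID are illustrative:

    # Rough sketch of what `parse_uid` is assumed to return, based on the code it replaces.
    from typing import Tuple

    UID_DELIMITER = "."  # assumed delimiter between the DHT prefix and the block index

    def parse_uid_sketch(uid: str) -> Tuple[str, int]:
        prefix, index = uid.rsplit(UID_DELIMITER, 1)
        return prefix, int(index)

    print(parse_uid_sketch("bigscience/bloom-560m-petals.7"))  # ('bigscience/bloom-560m-petals', 7)
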
 
 
 
 

+ 1 - 1
src/petals/server/throughput.py

@@ -56,7 +56,7 @@ def get_server_throughput(
 
 
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
-    with open(lock_path, "wb") as lock_fd:
+    with open(lock_path, "wb+") as lock_fd:
         logger.info("Loading throughput info")
         logger.info("Loading throughput info")
         fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
         fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
         # The OS will release the lock when lock_fd is closed or the process is killed
         # The OS will release the lock when lock_fd is closed or the process is killed
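
Both lock files in this commit (here and in src/petals/utils/disk_cache.py) are now opened in "wb+" mode, which, unlike "wb", keeps the descriptor readable as well as writable while the lock is held. A standalone sketch of the same flock pattern; the path below is hypothetical:

    # Standalone sketch of the advisory-lock pattern used above (POSIX-only; path is hypothetical).
    import fcntl
    import os
    from pathlib import Path

    lock_path = Path("/tmp/petals_example/throughput.lock")
    os.makedirs(lock_path.parent, exist_ok=True)

    with open(lock_path, "wb+") as lock_fd:
        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)  # blocks until no other process holds the lock
        # ... read or refresh the cached info here ...
        # The OS releases the lock when lock_fd is closed or the process exits.
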

+ 76 - 0
src/petals/utils/cuda_graphs.py

@@ -0,0 +1,76 @@
+import torch
+from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten
+
+
+def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3):
+    """Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass"""
+    assert not isinstance(callable, torch.nn.Module)
+    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
+        raise RuntimeError(
+            "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
+        )
+
+    flatten_arg, _ = _tree_flatten(sample_args)
+    flatten_sample_args = tuple(flatten_arg)
+    assert all(
+        isinstance(arg, torch.Tensor) for arg in flatten_arg
+    ), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed."
+
+    len_user_args = len(sample_args)
+    static_input_surface = flatten_sample_args
+
+    graph = torch.cuda.CUDAGraph()
+
+    # Warmup
+    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
+    # from ending up in any captures.
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        for _ in range(num_warmup_iters):
+            outputs, _ = _tree_flatten(callable(*sample_args))
+        del outputs
+    torch.cuda.current_stream().wait_stream(s)
+
+    # Capture forward graph
+    with torch.cuda.graph(graph):
+        outputs = callable(*sample_args)
+
+    flatten_outputs, output_unflatten_spec = _tree_flatten(outputs)
+    static_outputs = tuple(flatten_outputs)
+
+    def make_graphed_function(
+        graph,
+        len_user_args,
+        output_unflatten_spec,
+        static_input_surface,
+        static_outputs,
+    ):
+        def replay_graph(*inputs):
+            # At this stage, only the user args may (potentially) be new tensors.
+            for i in range(len_user_args):
+                if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
+                    static_input_surface[i].copy_(inputs[i])
+            graph.replay()
+            assert isinstance(static_outputs, tuple)
+            return tuple(o.detach() for o in static_outputs)
+
+        def functionalized(*user_args):
+            # Replays the captured graph on the given user args; no autograd graph is built here.
+            # Assumes any parameters captured inside `callable` are unchanged since capture.
+            flatten_user_args, _ = _tree_flatten(user_args)
+            out = replay_graph(*flatten_user_args)
+            return _tree_unflatten(out, output_unflatten_spec)
+
+        return functionalized
+
+    # Put together the final graphed callable
+    graphed = make_graphed_function(
+        graph,
+        len_user_args,
+        output_unflatten_spec,
+        static_input_surface,
+        static_outputs,
+    )
+    return graphed
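
A minimal usage sketch for the new helper, assuming a CUDA device is available; the layer, shapes, and dtype below are illustrative:

    # Illustrative usage of make_inference_graphed_callable (requires a CUDA device).
    import torch

    from petals.utils.cuda_graphs import make_inference_graphed_callable

    layer = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
    sample_input = torch.randn(1, 1024, device="cuda", dtype=torch.float16)

    # Pass the bound method, not the Module itself (the helper asserts against nn.Module inputs).
    graphed_forward = make_inference_graphed_callable(layer.forward, sample_args=(sample_input,))

    with torch.inference_mode():
        out = graphed_forward(sample_input)                     # replays the captured graph
        out = graphed_forward(torch.randn_like(sample_input))   # new inputs are copied into the static buffers
    # Note: outputs point to the graph's static buffers and are overwritten by the next replay.
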

+ 41 - 12
src/petals/utils/dht.py

@@ -11,7 +11,16 @@ from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
 
 
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
+from petals.data_structures import (
+    CHAIN_DELIMITER,
+    UID_DELIMITER,
+    ModuleUID,
+    RemoteModuleInfo,
+    RemoteSpanInfo,
+    ServerInfo,
+    ServerState,
+    parse_uid,
+)
 
 
 logger = get_logger(__name__)
 
 
@@ -70,7 +79,7 @@ def get_remote_module_infos(
     *,
     latest: bool = False,
     return_future: bool = False,
-) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
+) -> Union[List[RemoteModuleInfo], MPFuture]:
     return dht.run_coroutine(
         partial(
             _get_remote_module_infos,
@@ -90,7 +99,7 @@ async def _get_remote_module_infos(
     active_adapter: Optional[str],
     expiration_time: Optional[DHTExpiration],
     latest: bool,
-) -> List[Optional[RemoteModuleInfo]]:
+) -> List[RemoteModuleInfo]:
     if latest:
         assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
         expiration_time = math.inf
@@ -99,14 +108,14 @@ async def _get_remote_module_infos(
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
 
 
-    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
-    for i, uid in enumerate(uids):
-        metadata = found[uid]
+    modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids]
+    for module_info in modules:
+        metadata = found[module_info.uid]
         if metadata is None or not isinstance(metadata.value, dict):
             if metadata is not None:
-                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
+                logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}")
             continue
-        servers = {}
+
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
@@ -116,9 +125,29 @@ async def _get_remote_module_infos(
                     logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                     logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                     continue
                     continue
 
 
-                servers[peer_id] = server_info
+                module_info.servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
-                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
-        if servers:
-            modules[i] = RemoteModuleInfo(uid, servers)
+                logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}")
     return modules
+
+
+def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]:
+    block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
+    num_blocks = len(module_infos)
+
+    spans = {}
+    for block_idx, module_info in enumerate(module_infos):
+        for peer_id, server_info in sorted(module_info.servers.items()):
+            if server_info.state.value < min_state.value:
+                continue
+
+            if peer_id not in spans or spans[peer_id].state.value < server_info.state.value:
+                spans[peer_id] = RemoteSpanInfo(
+                    peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info
+                )
+                if server_info.start_block is not None and server_info.end_block is not None:
+                    spans[peer_id].start = max(server_info.start_block - block_offset, 0)
+                    spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
+            elif spans[peer_id].state == server_info.state:
+                spans[peer_id].end = max(spans[peer_id].end, block_idx + 1)
+    return spans
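
A small usage sketch for the new `compute_spans` helper, assuming the UIDs cover a contiguous range of blocks as elsewhere in the codebase; `summarize_swarm` is a hypothetical helper written only for illustration:

    # Sketch: list which servers cover which blocks, keeping only servers that are at least ONLINE.
    from typing import List

    from hivemind import DHT

    from petals.data_structures import ModuleUID, ServerState
    from petals.utils.dht import compute_spans, get_remote_module_infos

    def summarize_swarm(dht: DHT, uids: List[ModuleUID]) -> None:
        module_infos = get_remote_module_infos(dht, uids, latest=True)
        spans = compute_spans(module_infos, min_state=ServerState.ONLINE)
        for peer_id, span in spans.items():
            print(f"{peer_id} serves blocks {span.start}:{span.end}")
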

+ 1 - 1
src/petals/utils/disk_cache.py

@@ -22,7 +22,7 @@ def _blocks_lock(cache_dir: Optional[str], mode: int):
     lock_path = Path(cache_dir, BLOCKS_LOCK_FILE)
 
 
     os.makedirs(lock_path.parent, exist_ok=True)
-    with open(lock_path, "wb") as lock_fd:
+    with open(lock_path, "wb+") as lock_fd:
         fcntl.flock(lock_fd.fileno(), mode)
         # The OS will release the lock when lock_fd is closed or the process is killed
         yield

+ 93 - 5
tests/test_optimized_layers.py

@@ -3,6 +3,7 @@ from typing import Optional, Tuple
 import pytest
 import torch
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
 
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
@@ -94,10 +95,91 @@ class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
         return state
 
 
 
 
-@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models")
+class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        batch_size, seq_length, _ = hidden_states.shape
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        past_key_value = layer_past
+        if past_key_value is not None:
+            past_key_values_length = past_key_value[0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
+
+        if position_ids is None:
+            device = hidden_states.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+            )
+        attention_mask = LlamaModel._prepare_decoder_attention_mask(
+            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        )
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = self._reorder_cache_from_llama_to_bloom(
+                present_key_value, batch_size, seq_length_with_past
+            )
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom_to_llama(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        key_states = key_states.permute(0, 2, 1)
+        key_states = key_states.view(
+            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        value_states = value_states.view(*key_states.shape)
+        return (key_states, value_states)
+
+    def _reorder_cache_from_llama_to_bloom(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        value_states = value_states.view(
+            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        key_states = key_states.view(*value_states.shape)
+        key_states = key_states.permute(0, 2, 1)
+        return (key_states, value_states)
+
+
 @pytest.mark.parametrize("device", ["cpu", "cuda:0"])
 @pytest.mark.forked
-def test_falcon(device):
+def test_optimized_block(device):
     if device == "cuda:0" and not torch.cuda.is_available():
     if device == "cuda:0" and not torch.cuda.is_available():
         pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
         pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
 
 
@@ -108,11 +190,17 @@ def test_falcon(device):
     quant_type = QuantType.NONE
 
 
     block = config.block_class(config).to(dtype)
-    block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+    block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+
+    if config.model_type == "falcon":
+        unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
+    elif config.model_type == "llama":
+        unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
+    else:
+        pytest.skip(f"This test is not applicable to {config.model_type} models")
 
 
-    unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
     unopt_block = convert_block(
-        unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
+        unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
     )
 
 
     unopt_block.load_state_dict(block.state_dict())
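
For orientation, a toy illustration of the two cache layouts that the reordering helpers above translate between; the tensor sizes are arbitrary:

    # Toy illustration of the cache layouts converted by the helpers above (sizes are arbitrary).
    import torch

    batch, kv_heads, seq, head_dim = 2, 4, 5, 8

    # Bloom-style layout (as implied by _reorder_cache_from_bloom_to_llama):
    key_bloom = torch.randn(batch * kv_heads, head_dim, seq)
    value_bloom = torch.randn(batch * kv_heads, seq, head_dim)

    # Llama-style layout expected by LlamaDecoderLayer:
    key_llama = key_bloom.permute(0, 2, 1).view(batch, kv_heads, seq, head_dim)
    value_llama = value_bloom.view(batch, kv_heads, seq, head_dim)
    assert key_llama.shape == value_llama.shape == (batch, kv_heads, seq, head_dim)
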