
Merge branch 'main' into speculative_inference

justheuristic 11 months ago
parent
commit
13111911a6

+ 37 - 18
README.md

@@ -8,14 +8,14 @@
     <br>
 </p>
 
-Generate text with distributed **Llama 2** (70B), **Falcon** (40B+), **BLOOM** (176B) (or their derivatives), and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
+Generate text with distributed **Llama 3.1** (up to 405B), **Mixtral** (8x22B), **Falcon** (40B+) or **BLOOM** (176B) and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
 
 ```python
 from transformers import AutoTokenizer
 from petals import AutoDistributedModelForCausalLM
 
 # Choose any model available at https://health.petals.dev
-model_name = "petals-team/StableBeluga2"  # This one is fine-tuned Llama 2 (70B)
+model_name = "meta-llama/Meta-Llama-3.1-405B-Instruct"
 
 # Connect to a distributed network hosting model layers
 tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -31,22 +31,26 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
     🚀 &nbsp;<b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b>
 </p>
 
-🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
+🦙 **Want to run Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
 
-🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
+🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
 
 💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
 
 ## Connect your GPU and increase Petals capacity
 
-Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can check out [available models](https://health.petals.dev) and help serving one of them! As an example, here is how to host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your GPU:
+Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can help serving one of the [available models](https://health.petals.dev) or host a new model from 🤗 [Model Hub](https://huggingface.co/models)!
+
+As an example, here is how to host a part of [Llama 3.1 (405B) Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) on your GPU:
+
+🦙 **Want to host Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model.
 
 🐧 **Linux + Anaconda.** Run these commands for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
 
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
 pip install git+https://github.com/bigscience-workshop/petals
-python -m petals.cli.run_server petals-team/StableBeluga2
+python -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct
 ```
 
 🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.
@@ -56,7 +60,7 @@ python -m petals.cli.run_server petals-team/StableBeluga2
 ```bash
 sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \
     learningathome/petals:main \
-    python -m petals.cli.run_server --port 31330 petals-team/StableBeluga2
+    python -m petals.cli.run_server --port 31330 meta-llama/Meta-Llama-3.1-405B-Instruct
 ```
 
 🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:
@@ -64,19 +68,17 @@ sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cach
 ```bash
 brew install python
 python3 -m pip install git+https://github.com/bigscience-workshop/petals
-python3 -m petals.cli.run_server petals-team/StableBeluga2
+python3 -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct
 ```
 
 <p align="center">
     📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)
 </p>
 
-💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
-
-🦙 **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add `--token YOUR_TOKEN_HERE` to the `python -m petals.cli.run_server` command.
-
 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
 
+💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
+
 🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
 
 ## How does it work?
@@ -120,22 +122,39 @@ Please see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
 
 Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing.
 
-### 📜 Citation
+### 📜 Citations
 
 Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel.
 [Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188)
-_arXiv preprint arXiv:2209.01188,_ 2022.
+_Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)._ 2023.
 
 ```bibtex
-@article{borzunov2022petals,
+@inproceedings{borzunov2023petals,
   title = {Petals: Collaborative Inference and Fine-tuning of Large Models},
-  author = {Borzunov, Alexander and Baranchuk, Dmitry and Dettmers, Tim and Ryabinin, Max and Belkada, Younes and Chumachenko, Artem and Samygin, Pavel and Raffel, Colin},
-  journal = {arXiv preprint arXiv:2209.01188},
-  year = {2022},
+  author = {Borzunov, Alexander and Baranchuk, Dmitry and Dettmers, Tim and Riabinin, Maksim and Belkada, Younes and Chumachenko, Artem and Samygin, Pavel and Raffel, Colin},
+  booktitle = {Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
+  pages = {558--568},
+  year = {2023},
   url = {https://arxiv.org/abs/2209.01188}
 }
 ```
 
+Alexander Borzunov, Max Ryabinin, Artem Chumachenko, Dmitry Baranchuk, Tim Dettmers, Younes Belkada, Pavel Samygin, and Colin Raffel.
+[Distributed inference and fine-tuning of large language models over the Internet.](https://arxiv.org/abs/2312.08361)
+_Advances in Neural Information Processing Systems_ 36 (2023).
+
+```bibtex
+@inproceedings{borzunov2023distributed,
+  title = {Distributed inference and fine-tuning of large language models over the {I}nternet},
+  author = {Borzunov, Alexander and Ryabinin, Max and Chumachenko, Artem and Baranchuk, Dmitry and Dettmers, Tim and Belkada, Younes and Samygin, Pavel and Raffel, Colin},
+  booktitle = {Advances in Neural Information Processing Systems},
+  volume = {36},
+  pages = {12312--12331},
+  year = {2023},
+  url = {https://arxiv.org/abs/2312.08361}
+}
+```
+
 --------------------------------------------------------------------------------
 
 <p align="center">

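For reference, the client snippet that this README hunk edits continues past the context lines shown above; here is a minimal sketch of the full flow, with the model name taken from the new `+` line and the remaining calls reconstructed from the surrounding README text (approximate, not part of the diff):

```python
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "meta-llama/Meta-Llama-3.1-405B-Instruct"  # any model listed at https://health.petals.dev

# Connect to the public swarm that hosts this model's layers
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

# Run inference as if the whole model were local
inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
```
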
+ 4 - 5
setup.cfg

@@ -32,22 +32,21 @@ package_dir =
 packages = find:
 python_requires = >=3.8
 install_requires =
-    torch>=1.12,<2.3.0
+    torch>=1.12
     bitsandbytes==0.41.1
     accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers==4.41.2  # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.43.1  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
-    pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
-    hivemind==1.1.10.post2
+    hivemind @ git+https://github.com/learning-at-home/hivemind.git@213bff98a62accb91f254e2afdccbf1d69ebdea9
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
     cpufeature>=0.2.0; platform_machine == "x86_64"
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft==0.5.0
+    peft==0.8.2
     safetensors>=0.3.1
     Dijkstar>=2.6.0
     numpy<2

+ 2 - 2
src/petals/__init__.py

@@ -22,8 +22,8 @@ __version__ = "2.3.0.dev2"
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.41.2") <= version.parse(transformers.__version__) < version.parse("4.42.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.41.2,<4.42.0"
+        version.parse("4.43.1") <= version.parse(transformers.__version__) < version.parse("4.44.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.43.1,<4.44.0"
 
 
 def _override_bfloat16_mode_default():

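The assert above runs at import time and can be skipped for local experiments via the `PETALS_IGNORE_DEPENDENCY_VERSION` variable read in the context lines; a small sketch (the workflow itself is an assumption, only the variable name and version bounds come from the diff):

```python
import os

# Must be set before petals is imported, because the check runs at import time
os.environ["PETALS_IGNORE_DEPENDENCY_VERSION"] = "1"

import petals  # skips the transformers>=4.43.1,<4.44.0 assert shown above
print(petals.__version__)
```
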
+ 1 - 1
src/petals/client/inference_session.py

@@ -336,7 +336,7 @@ class InferenceSession:
                         self._update_sequence(server_idx, block_idx, attempt_no)
 
                     server_session = self._server_sessions[server_idx]
-                    assert server_session.position == self.position
+                    assert server_session.position == self.position, f"Position mismatch: {server_session.position} and {self.position}"
                     inputs = server_session.step(
                         inputs,
                         prompts[server_session.span.start : server_session.span.end],

+ 1 - 1
src/petals/data_structures.py

@@ -2,7 +2,7 @@ import dataclasses
 from enum import Enum
 from typing import Any, Dict, Optional, Sequence, Tuple
 
-import pydantic
+import pydantic.v1 as pydantic
 from hivemind import PeerID
 from hivemind.moe.expert_uid import ExpertUID
 

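The aliased import works because pydantic 2.x re-exports the legacy 1.x API under `pydantic.v1`, so v1-style models keep working unchanged; a minimal illustration (the model below is hypothetical, not taken from Petals):

```python
import pydantic.v1 as pydantic  # pydantic 2.x ships the old 1.x API in this namespace


class ExampleServerInfo(pydantic.BaseModel):  # hypothetical model, for illustration only
    peer_id: str
    throughput: float = 0.0

    @pydantic.validator("throughput")
    def _check_non_negative(cls, value):
        if value < 0:
            raise ValueError("throughput must be >= 0")
        return value


info = ExampleServerInfo(peer_id="QmExamplePeer", throughput=1.5)
print(info.json())
```
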
+ 1 - 1
src/petals/models/bloom/block.py

@@ -7,7 +7,7 @@ from typing import Optional, Tuple
 
 import torch
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
+from transformers.models.bloom.modeling_bloom import BloomBlock, build_alibi_tensor
 
 from petals.utils.misc import is_dummy
 

+ 1 - 1
src/petals/models/bloom/config.py

@@ -24,7 +24,7 @@ class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfi
     def from_pretrained(
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
     ):
-        logger.info("Make sure you follow the BLOOM's terms of use: https://bit.ly/bloom-license")
+        logger.info("Make sure you follow the BLOOM terms of use: https://bit.ly/bloom-license")
 
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
         if loading_from_repo and dht_prefix is None:

+ 2 - 0
src/petals/models/llama/__init__.py

@@ -5,11 +5,13 @@ from petals.models.llama.model import (
     DistributedLlamaForSequenceClassification,
     DistributedLlamaModel,
 )
+from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration
 from petals.utils.auto_config import register_model_classes
 
 register_model_classes(
     config=DistributedLlamaConfig,
     model=DistributedLlamaModel,
     model_for_causal_lm=DistributedLlamaForCausalLM,
+    model_for_speculative=DistributedLlamaForSpeculativeGeneration,
     model_for_sequence_classification=DistributedLlamaForSequenceClassification,
 )

+ 2 - 2
src/petals/models/llama/block.py

@@ -15,7 +15,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaConfig,
     LlamaDecoderLayer,
     LlamaMLP,
-    LlamaModel,
     LlamaRMSNorm,
     repeat_kv,
     rotate_half,
@@ -132,7 +131,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: LlamaConfig):
         nn.Module.__init__(self)
         self.hidden_size = config.hidden_size
-        self.self_attn = OptimizedLlamaAttention(config=config)
+        self.self_attn = OptimizedLlamaAttention(config=config, layer_idx=0)
+        # layer_idx only matters for KV caching, and we re-implement it in Petals
         self.mlp = LlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+ 2 - 2
src/petals/models/llama/config.py

@@ -27,8 +27,8 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfi
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
     ):
         logger.info(
-            "Make sure you follow the LLaMA's terms of use: "
-            "https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
+            "Make sure you follow the Llama terms of use: "
+            "https://llama.meta.com/llama3/license, https://llama.meta.com/llama2/license"
         )
 
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)

+ 111 - 0
src/petals/models/llama/speculative_model.py

@@ -0,0 +1,111 @@
+from typing import Optional, Union
+
+import torch
+from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama import LlamaForCausalLM
+
+from petals.models.llama.config import DistributedLlamaConfig
+from petals.models.llama.model import DistributedLlamaForCausalLM
+
+
+class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
+    def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
+        DistributedLlamaForCausalLM.__init__(self, config)
+        self.small_model = small_model
+
+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        generation_config: GenerationConfig,
+        synced_gpus: bool,
+        streamer: Optional["BaseStreamer"],
+        logits_warper: Optional[LogitsProcessorList],
+        speculative_inference_iteration_size: int = 10,
+        **model_kwargs,
+    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+        assert not generation_config.do_sample, "sample is not working for speculative generation now"
+        assert not synced_gpus, "synced_gpus is not working for speculative generation now"
+        assert (
+            not generation_config.return_dict_in_generate
+        ), "return_dict_in_generate is not working for speculative generation now"
+
+        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+
+        # keep track of which sequences are already finished
+        batch_size = input_ids.shape[0]
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        finished = False
+        firsts = True
+
+        while not finished:
+            speculative_inference_iteration_size = min(
+                speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1]
+            )
+            with torch.no_grad():
+                speculative_outputs = self.small_model.generate(
+                    input_ids,
+                    max_new_tokens=speculative_inference_iteration_size,
+                    do_sample=False,
+                )
+                speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:]
+
+            full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
+            assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1]
+
+            input_for_validation = full_sequence
+            if not firsts:
+                self.active_session.position = input_ids.shape[1] - 1
+                input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :]
+            else:
+                firsts = False
+            input_for_validation = input_for_validation[:, :-1]
+            with torch.no_grad():
+                precise_model_outputs = self(input_for_validation)
+            full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone()
+
+            all_valid_tokens = []
+            first_token = None
+            for i in range(speculative_inference_iteration_size):
+                token_logits = full_token_logits[:, i, :]
+                token_scores = logits_processor(
+                    input_for_validation[:, : -speculative_inference_iteration_size + 1 + i], token_logits
+                )
+                valid_token = torch.argmax(token_scores, dim=-1)
+
+                if first_token is None:
+                    first_token = valid_token
+
+                if valid_token.item() == speculative_tokens[:, i].item():
+                    all_valid_tokens.append(valid_token.unsqueeze(-1))
+                else:
+                    break
+
+            if not all_valid_tokens and first_token is not None:
+                all_valid_tokens.append(first_token.unsqueeze(-1))
+            all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * (
+                    1 - unfinished_sequences
+                )
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)
+
+            if streamer is not None:
+                streamer.put(all_valid_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
+            finished = unfinished_sequences.max() == 0
+
+            del precise_model_outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        return input_ids

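A usage sketch for the new class, mirroring the test added at the end of this diff: a small local Llama-family model drafts several tokens per round, and the distributed model validates them in a single forward pass. The draft checkpoint name is a placeholder, and greedy decoding is required because `_sample` asserts `do_sample=False`:

```python
import torch
import transformers
from petals import AutoDistributedSpeculativeModel

MODEL_NAME = "meta-llama/Meta-Llama-3.1-405B-Instruct"
DRAFT_NAME = "path/to/a-small-llama-checkpoint"  # placeholder for any small Llama-family draft model

draft_model = transformers.AutoModelForCausalLM.from_pretrained(DRAFT_NAME, torch_dtype=torch.float32)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoDistributedSpeculativeModel.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float32, small_model=draft_model
)

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)  # greedy decoding only for now
print(tokenizer.decode(outputs[0]))
```
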
+ 1 - 2
src/petals/models/mixtral/block.py

@@ -1,4 +1,3 @@
-import json
 from typing import Optional, Tuple
 
 import torch
@@ -8,7 +7,7 @@ from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
 
 
 class WrappedMixtralBlock(MixtralDecoderLayer):

+ 3 - 0
src/petals/models/mixtral/model.py

@@ -55,6 +55,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         output_hidden_states: Optional[bool] = None,
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -70,6 +71,8 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        if cache_position is not None:
+            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"

+ 1 - 1
src/petals/server/block_utils.py

@@ -32,7 +32,7 @@ def get_block_size(
             dtype is not None and quant_type is not None
         ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
 
-    with init_empty_weights(include_buffers=True):
+    with init_empty_weights(include_buffers=False):
         block = get_model_block(config)
         n_params = sum(param.numel() for param in block.parameters())
 

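The changed call only affects how the dummy block is materialized for parameter counting; a small sketch of the accelerate pattern used here, with a toy layer standing in for `get_model_block(config)`:

```python
import torch
from accelerate import init_empty_weights

# Parameters are created on the meta device, so no real memory is allocated while counting them.
# With include_buffers=False, registered buffers keep real tensors instead of meta ones.
with init_empty_weights(include_buffers=False):
    block = torch.nn.TransformerEncoderLayer(d_model=4096, nhead=32)  # stand-in for get_model_block(config)
    n_params = sum(param.numel() for param in block.parameters())

print(f"{n_params / 1e6:.1f}M params, parameter device: {next(block.parameters()).device}")  # meta
```
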
+ 0 - 5
src/petals/server/from_pretrained.py

@@ -64,10 +64,6 @@ def load_pretrained_block(
         max_disk_space=max_disk_space,
     )
 
-    # dummy load, check that keys match
-    report = block.load_state_dict(state_dict, strict=False)
-    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
-
     for param_name, _ in block.named_parameters():
         assert param_name in state_dict, f"{param_name} not in state dict"
         param = state_dict[param_name]
@@ -76,7 +72,6 @@ def load_pretrained_block(
         set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
 
     logger.info(f"Loaded {model_name} block {block_index}")
-    logger.debug(f"Details: {report}")
     return block
 
 

+ 1 - 0
src/petals/utils/__init__.py

@@ -3,5 +3,6 @@ from petals.utils.auto_config import (
     AutoDistributedModel,
     AutoDistributedModelForCausalLM,
     AutoDistributedModelForSequenceClassification,
+    AutoDistributedSpeculativeModel,
 )
 from petals.utils.dht import declare_active_modules, get_remote_module_infos

+ 5 - 0
src/petals/utils/auto_config.py

@@ -15,6 +15,7 @@ class _ModelClasses:
     config: Type[PretrainedConfig]
     model: Optional[Type[PreTrainedModel]] = None
     model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
+    model_for_speculative: Optional[Type[PreTrainedModel]] = None
     model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
 
 
@@ -90,5 +91,9 @@ class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase
     _mapping_field = "model_for_causal_lm"
 
 
+class AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistributedBase):
+    _mapping_field = "model_for_speculative"
+
+
 class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model_for_sequence_classification"

+ 1 - 1
src/petals/utils/convert_block.py

@@ -61,7 +61,7 @@ def convert_block(
     if adapters:
         from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
-        create_lora_adapter(block, quant_type=quant_type)
+        create_lora_adapter(block)
         for adapter_name in adapters:
             adapter_config, adapter_state_dict = load_peft(
                 adapter_name,

+ 43 - 48
src/petals/utils/peft.py

@@ -1,7 +1,7 @@
 import contextlib
 import re
 import time
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
 import bitsandbytes as bnb
 import torch
@@ -12,7 +12,7 @@ from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
 from peft.config import PeftConfig
 from peft.tuners import lora
-from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
+from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
 from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
@@ -25,6 +25,9 @@ from petals.utils.misc import get_size_in_bytes
 logger = get_logger(__name__)
 
 
+COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]
+
+
 def check_peft_repository(repo_id: str) -> bool:
     return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")
 
@@ -151,6 +154,18 @@ class AdapterContextMixin:
     def active_adapter(self, value: Optional[str]):
         assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""
 
+    @property
+    def active_adapters(self):
+        return [self._context_active_adapter]
+
+    def set_adapter(self, adapter_names) -> None:
+        """
+        In PEFT, this function makes the adapter trainable. However, in Petals environment this is not possible now. Thus,
+        this code removes this functionality.
+        Link to peft code: https://github.com/huggingface/peft/blob/98f4db2c7990ef9c879a0e1da9a28a19a04701ef/src/peft/tuners/tuners_utils.py#L463
+        """
+        pass
+
 
 using_adapter = AdapterContextMixin.using_adapter
 
@@ -158,60 +173,39 @@ using_adapter = AdapterContextMixin.using_adapter
 class LoraLinear(AdapterContextMixin, lora.Linear):
     """LoRA linear layer that uses adapter selected via using_adapter"""
 
+    def __init__(self, base_layer, adapter_name: str):
+        nn.Module.__init__(self)
+        lora.LoraLayer.__init__(self, base_layer)
+
+        self._active_adapter = adapter_name
+        self.is_target_conv_1d_layer = False
+
 
-class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
+class LoraLinear8bitLt(LoraLinear, lora.Linear8bitLt):
     """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
 
 
-class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
+class LoraLinear4bit(LoraLinear, lora.Linear4bit):
     """LoRA linear 4-bit that uses adapter selected via using_adapter"""
 
 
-def create_lora_adapter(block, quant_type: QuantType):
-    for _, module in block.named_modules():
+def create_lora_adapter(block):
+    for module_name, module in block.named_modules():
+        if isinstance(module, LoraLinear):
+            continue
         for child_name, child in module.named_children():
-            lora_wrapped_child = None
-            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
-                continue
-            if quant_type == QuantType.INT8:
-                kwargs = {
-                    "has_fp16_weights": False,
-                    "threshold": 6.0,
-                    "bias": hasattr(child, "bias") and child.bias is not None,
-                }
-                lora_wrapped_child = LoraLinear8bitLt(
-                    AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    **kwargs,
-                )
-            elif quant_type == QuantType.NF4:
-                kwargs = {
-                    "compress_statistics": True,
-                    "quant_type": "nf4",
-                    "blocksize": 64,
-                    "bias": hasattr(child, "bias") and child.bias is not None,
-                }
-                lora_wrapped_child = LoraLinear4bit(
-                    AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    **kwargs,
-                )
-                lora_wrapped_child.compute_dtype = child.compute_dtype
-            else:
-                bias = hasattr(child, "bias") and child.bias is not None
-                lora_wrapped_child = LoraLinear(
+            lora_class = None
+            if isinstance(child, nn.Linear):
+                lora_class = LoraLinear
+            elif isinstance(child, bnb.nn.Linear8bitLt):
+                lora_class = LoraLinear8bitLt
+            elif isinstance(child, bnb.nn.Linear4bit):
+                lora_class = LoraLinear4bit
+            if lora_class:
+                lora_wrapped_child = lora_class(
+                    child,
                     AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    bias=bias,
                 )
-            if lora_wrapped_child:
-                lora_wrapped_child.weight = child.weight
-                lora_wrapped_child.bias = child.bias
-                for p in lora_wrapped_child.parameters():
-                    p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)
 
 
@@ -240,6 +234,7 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                             adapter_name,
                             peft_config["r"],
                             peft_config["lora_alpha"],
+                            use_rslora=peft_config.get("use_rslora", False),
                             lora_dropout=peft_config["lora_dropout"],
                             init_lora_weights=peft_config["init_lora_weights"],
                         )
@@ -272,10 +267,10 @@ def estimate_adapter_memory_per_block(
     **load_peft_kwargs,
 ) -> int:
     """Get the number of extra bytes used to store a set of adapters per given block"""
-    with init_empty_weights(include_buffers=True):
+    with init_empty_weights(include_buffers=False):
         block = get_model_block(block_config)
         base_block_parameters = sum(p.numel() for p in block.parameters())
-        create_lora_adapter(block, quant_type=QuantType.NONE)
+        create_lora_adapter(block)
 
         for adapter in adapters:
             peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)

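A small sketch of what the rewritten `create_lora_adapter` does: each supported linear layer in a block is replaced by a LoRA wrapper that keeps the original module as its base layer, with the wrapper class now chosen from the layer type instead of a `quant_type` argument. The toy block is illustrative, and importing `petals.utils.peft` assumes petals with its peft/bitsandbytes dependencies is installed:

```python
import torch.nn as nn
from petals.utils.peft import LoraLinear, create_lora_adapter

# A toy stand-in for a transformer block
toy_block = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))

create_lora_adapter(toy_block)

# Each nn.Linear should now be wrapped; the original layer survives as .base_layer
print([type(m).__name__ for m in toy_block])           # expected: ['LoraLinear', 'ReLU', 'LoraLinear']
print(isinstance(toy_block[0].base_layer, nn.Linear))  # expected: True
```
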
+ 50 - 2
tests/test_speculative_generation.py

@@ -2,8 +2,14 @@ import random
 
 import pytest
 import torch
+import transformers
 
-from petals import AutoDistributedConfig, RemoteSequential
+from petals import (
+    AutoDistributedConfig,
+    AutoDistributedSpeculativeModel,
+    DistributedLlamaForSpeculativeGeneration,
+    RemoteSequential,
+)
 from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
@@ -26,7 +32,6 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
     with torch.inference_mode():
         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
             initial_outputs_inference = sess.step(inputs)
-
             sess.position = 2
             secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
             result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
@@ -35,3 +40,46 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
     (outputs_local,) = ref_block(short_inputs)
 
     assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
+
+
+@pytest.fixture
+def noisy_model():
+    noisy_model = transformers.AutoModelForCausalLM.from_pretrained(
+        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    lm_head = noisy_model.get_output_embeddings()
+    assert isinstance(lm_head, torch.nn.Linear)
+    with torch.no_grad():
+        lm_head.weight += torch.randn_like(lm_head.weight) * 0.02
+    return noisy_model
+
+
+@pytest.fixture
+def model():
+    return transformers.AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+
+
+@pytest.fixture
+def tokenizer():
+    # We set use_fast=False since LlamaTokenizerFast is slow on load
+    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+
+
+@pytest.mark.forked
+@pytest.mark.skipif(
+    "llama" not in MODEL_NAME.lower(),
+    reason="Speculative generation now works only for llama models",
+)
+def test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3):
+    speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model
+    )
+
+    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+
+    generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False)
+    generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False)
+
+    assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference)