Ver código fonte

Merge branch 'main' into speculative_inference

justheuristic 11 meses atrás
pai
commit
13111911a6

+ 37 - 18
README.md

@@ -8,14 +8,14 @@
     <br>
 </p>
 
-Generate text with distributed **Llama 2** (70B), **Falcon** (40B+), **BLOOM** (176B) (or their derivatives), and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
+Generate text with distributed **Llama 3.1** (up to 405B), **Mixtral** (8x22B), **Falcon** (40B+) or **BLOOM** (176B) and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:
 
 ```python
 from transformers import AutoTokenizer
 from petals import AutoDistributedModelForCausalLM
 
 # Choose any model available at https://health.petals.dev
-model_name = "petals-team/StableBeluga2"  # This one is fine-tuned Llama 2 (70B)
+model_name = "meta-llama/Meta-Llama-3.1-405B-Instruct"
 
 # Connect to a distributed network hosting model layers
 tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -31,22 +31,26 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
     🚀 &nbsp;<b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b>
 </p>
 
-🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
+🦙 **Want to run Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
 
-🦙 **Want to run Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
+🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
 
 💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
 
 ## Connect your GPU and increase Petals capacity
 
-Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can check out [available models](https://health.petals.dev) and help serving one of them! As an example, here is how to host a part of [Stable Beluga 2](https://huggingface.co/stabilityai/StableBeluga2) on your GPU:
+Petals is a community-run system &mdash; we rely on people sharing their GPUs. You can help serving one of the [available models](https://health.petals.dev) or host a new model from 🤗 [Model Hub](https://huggingface.co/models)!
+
+As an example, here is how to host a part of [Llama 3.1 (405B) Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) on your GPU:
+
+🦙 **Want to host Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model.
 
 🐧 **Linux + Anaconda.** Run these commands for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
 
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
 pip install git+https://github.com/bigscience-workshop/petals
-python -m petals.cli.run_server petals-team/StableBeluga2
+python -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct
 ```
 
 🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.
@@ -56,7 +60,7 @@ python -m petals.cli.run_server petals-team/StableBeluga2
 ```bash
 sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \
     learningathome/petals:main \
-    python -m petals.cli.run_server --port 31330 petals-team/StableBeluga2
+    python -m petals.cli.run_server --port 31330 meta-llama/Meta-Llama-3.1-405B-Instruct
 ```
 
 🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:
@@ -64,19 +68,17 @@ sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cach
 ```bash
 brew install python
 python3 -m pip install git+https://github.com/bigscience-workshop/petals
-python3 -m petals.cli.run_server petals-team/StableBeluga2
+python3 -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct
 ```
 
 <p align="center">
     📚 &nbsp;<b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)
 </p>
 
-💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
-
-🦙 **Want to host Llama 2?** Request access to its weights at the ♾️ [Meta AI website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) and 🤗 [Model Hub](https://huggingface.co/meta-llama/Llama-2-70b-hf), generate an 🔑 [access token](https://huggingface.co/settings/tokens), then add `--token YOUR_TOKEN_HERE` to the `python -m petals.cli.run_server` command.
-
 🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
 
+💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
+
 🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
 
 ## How does it work?
@@ -120,22 +122,39 @@ Please see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
 
 Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing.
 
-### 📜 Citation
+### 📜 Citations
 
 Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel.
 [Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188)
-_arXiv preprint arXiv:2209.01188,_ 2022.
+_Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)._ 2023.
 
 ```bibtex
-@article{borzunov2022petals,
+@inproceedings{borzunov2023petals,
   title = {Petals: Collaborative Inference and Fine-tuning of Large Models},
-  author = {Borzunov, Alexander and Baranchuk, Dmitry and Dettmers, Tim and Ryabinin, Max and Belkada, Younes and Chumachenko, Artem and Samygin, Pavel and Raffel, Colin},
-  journal = {arXiv preprint arXiv:2209.01188},
-  year = {2022},
+  author = {Borzunov, Alexander and Baranchuk, Dmitry and Dettmers, Tim and Riabinin, Maksim and Belkada, Younes and Chumachenko, Artem and Samygin, Pavel and Raffel, Colin},
+  booktitle = {Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
+  pages = {558--568},
+  year = {2023},
   url = {https://arxiv.org/abs/2209.01188}
 }
 ```
 
+Alexander Borzunov, Max Ryabinin, Artem Chumachenko, Dmitry Baranchuk, Tim Dettmers, Younes Belkada, Pavel Samygin, and Colin Raffel.
+[Distributed inference and fine-tuning of large language models over the Internet.](https://arxiv.org/abs/2312.08361)
+_Advances in Neural Information Processing Systems_ 36 (2023).
+
+```bibtex
+@inproceedings{borzunov2023distributed,
+  title = {Distributed inference and fine-tuning of large language models over the {I}nternet},
+  author = {Borzunov, Alexander and Ryabinin, Max and Chumachenko, Artem and Baranchuk, Dmitry and Dettmers, Tim and Belkada, Younes and Samygin, Pavel and Raffel, Colin},
+  booktitle = {Advances in Neural Information Processing Systems},
+  volume = {36},
+  pages = {12312--12331},
+  year = {2023},
+  url = {https://arxiv.org/abs/2312.08361}
+}
+```
+
 --------------------------------------------------------------------------------
 
 <p align="center">

+ 4 - 5
setup.cfg

@@ -32,22 +32,21 @@ package_dir =
 packages = find:
 python_requires = >=3.8
 install_requires =
-    torch>=1.12,<2.3.0
+    torch>=1.12
     bitsandbytes==0.41.1
     accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers==4.41.2  # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.43.1  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
-    pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
-    hivemind==1.1.10.post2
+    hivemind @ git+https://github.com/learning-at-home/hivemind.git@213bff98a62accb91f254e2afdccbf1d69ebdea9
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
     cpufeature>=0.2.0; platform_machine == "x86_64"
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft==0.5.0
+    peft==0.8.2
     safetensors>=0.3.1
     Dijkstar>=2.6.0
     numpy<2

+ 2 - 2
src/petals/__init__.py

@@ -22,8 +22,8 @@ __version__ = "2.3.0.dev2"
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.41.2") <= version.parse(transformers.__version__) < version.parse("4.42.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.41.2,<4.42.0"
+        version.parse("4.43.1") <= version.parse(transformers.__version__) < version.parse("4.44.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.43.1,<4.44.0"
 
 
 def _override_bfloat16_mode_default():

+ 1 - 1
src/petals/client/inference_session.py

@@ -336,7 +336,7 @@ class InferenceSession:
                         self._update_sequence(server_idx, block_idx, attempt_no)
 
                     server_session = self._server_sessions[server_idx]
-                    assert server_session.position == self.position
+                    assert server_session.position == self.position, f"Position mismatch: {server_session.position} and {self.position}"
                     inputs = server_session.step(
                         inputs,
                         prompts[server_session.span.start : server_session.span.end],

+ 1 - 1
src/petals/data_structures.py

@@ -2,7 +2,7 @@ import dataclasses
 from enum import Enum
 from typing import Any, Dict, Optional, Sequence, Tuple
 
-import pydantic
+import pydantic.v1 as pydantic
 from hivemind import PeerID
 from hivemind.moe.expert_uid import ExpertUID
 

+ 1 - 1
src/petals/models/bloom/block.py

@@ -7,7 +7,7 @@ from typing import Optional, Tuple
 
 import torch
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
+from transformers.models.bloom.modeling_bloom import BloomBlock, build_alibi_tensor
 
 from petals.utils.misc import is_dummy
 

+ 1 - 1
src/petals/models/bloom/config.py

@@ -24,7 +24,7 @@ class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfi
     def from_pretrained(
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
     ):
-        logger.info("Make sure you follow the BLOOM's terms of use: https://bit.ly/bloom-license")
+        logger.info("Make sure you follow the BLOOM terms of use: https://bit.ly/bloom-license")
 
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
         if loading_from_repo and dht_prefix is None:

+ 2 - 0
src/petals/models/llama/__init__.py

@@ -5,11 +5,13 @@ from petals.models.llama.model import (
     DistributedLlamaForSequenceClassification,
     DistributedLlamaModel,
 )
+from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration
 from petals.utils.auto_config import register_model_classes
 
 register_model_classes(
     config=DistributedLlamaConfig,
     model=DistributedLlamaModel,
     model_for_causal_lm=DistributedLlamaForCausalLM,
+    model_for_speculative=DistributedLlamaForSpeculativeGeneration,
     model_for_sequence_classification=DistributedLlamaForSequenceClassification,
 )

+ 2 - 2
src/petals/models/llama/block.py

@@ -15,7 +15,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaConfig,
     LlamaDecoderLayer,
     LlamaMLP,
-    LlamaModel,
     LlamaRMSNorm,
     repeat_kv,
     rotate_half,
@@ -132,7 +131,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: LlamaConfig):
         nn.Module.__init__(self)
         self.hidden_size = config.hidden_size
-        self.self_attn = OptimizedLlamaAttention(config=config)
+        self.self_attn = OptimizedLlamaAttention(config=config, layer_idx=0)
+        # layer_idx only matters for KV caching, and we re-implement it in Petals
         self.mlp = LlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+ 2 - 2
src/petals/models/llama/config.py

@@ -27,8 +27,8 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfi
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
     ):
         logger.info(
-            "Make sure you follow the LLaMA's terms of use: "
-            "https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
+            "Make sure you follow the Llama terms of use: "
+            "https://llama.meta.com/llama3/license, https://llama.meta.com/llama2/license"
         )
 
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)

+ 111 - 0
src/petals/models/llama/speculative_model.py

@@ -0,0 +1,111 @@
+from typing import Optional, Union
+
+import torch
+from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama import LlamaForCausalLM
+
+from petals.models.llama.config import DistributedLlamaConfig
+from petals.models.llama.model import DistributedLlamaForCausalLM
+
+
+class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
+    def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
+        DistributedLlamaForCausalLM.__init__(self, config)
+        self.small_model = small_model
+
+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        generation_config: GenerationConfig,
+        synced_gpus: bool,
+        streamer: Optional["BaseStreamer"],
+        logits_warper: Optional[LogitsProcessorList],
+        speculative_inference_iteration_size: int = 10,
+        **model_kwargs,
+    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+        assert not generation_config.do_sample, "sample is not working for speculative generation now"
+        assert not synced_gpus, "synced_gpus is not working for speculative generation now"
+        assert (
+            not generation_config.return_dict_in_generate
+        ), "return_dict_in_generate is not working for speculative generation now"
+
+        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+
+        # keep track of which sequences are already finished
+        batch_size = input_ids.shape[0]
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        finished = False
+        firsts = True
+
+        while not finished:
+            speculative_inference_iteration_size = min(
+                speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1]
+            )
+            with torch.no_grad():
+                speculative_outputs = self.small_model.generate(
+                    input_ids,
+                    max_new_tokens=speculative_inference_iteration_size,
+                    do_sample=False,
+                )
+                speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:]
+
+            full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
+            assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1]
+
+            input_for_validation = full_sequence
+            if not firsts:
+                self.active_session.position = input_ids.shape[1] - 1
+                input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :]
+            else:
+                firsts = False
+            input_for_validation = input_for_validation[:, :-1]
+            with torch.no_grad():
+                precise_model_outputs = self(input_for_validation)
+            full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone()
+
+            all_valid_tokens = []
+            first_token = None
+            for i in range(speculative_inference_iteration_size):
+                token_logits = full_token_logits[:, i, :]
+                token_scores = logits_processor(
+                    input_for_validation[:, : -speculative_inference_iteration_size + 1 + i], token_logits
+                )
+                valid_token = torch.argmax(token_scores, dim=-1)
+
+                if first_token is None:
+                    first_token = valid_token
+
+                if valid_token.item() == speculative_tokens[:, i].item():
+                    all_valid_tokens.append(valid_token.unsqueeze(-1))
+                else:
+                    break
+
+            if not all_valid_tokens and first_token is not None:
+                all_valid_tokens.append(first_token.unsqueeze(-1))
+            all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * (
+                    1 - unfinished_sequences
+                )
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)
+
+            if streamer is not None:
+                streamer.put(all_valid_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
+            finished = unfinished_sequences.max() == 0
+
+            del precise_model_outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        return input_ids

+ 1 - 2
src/petals/models/mixtral/block.py

@@ -1,4 +1,3 @@
-import json
 from typing import Optional, Tuple
 
 import torch
@@ -8,7 +7,7 @@ from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
 
 
 class WrappedMixtralBlock(MixtralDecoderLayer):

+ 3 - 0
src/petals/models/mixtral/model.py

@@ -55,6 +55,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         output_hidden_states: Optional[bool] = None,
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -70,6 +71,8 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        if cache_position is not None:
+            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"

+ 1 - 1
src/petals/server/block_utils.py

@@ -32,7 +32,7 @@ def get_block_size(
             dtype is not None and quant_type is not None
         ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
 
-    with init_empty_weights(include_buffers=True):
+    with init_empty_weights(include_buffers=False):
         block = get_model_block(config)
         n_params = sum(param.numel() for param in block.parameters())
 

+ 0 - 5
src/petals/server/from_pretrained.py

@@ -64,10 +64,6 @@ def load_pretrained_block(
         max_disk_space=max_disk_space,
     )
 
-    # dummy load, check that keys match
-    report = block.load_state_dict(state_dict, strict=False)
-    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
-
     for param_name, _ in block.named_parameters():
         assert param_name in state_dict, f"{param_name} not in state dict"
         param = state_dict[param_name]
@@ -76,7 +72,6 @@ def load_pretrained_block(
         set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
 
     logger.info(f"Loaded {model_name} block {block_index}")
-    logger.debug(f"Details: {report}")
     return block
 
 

+ 1 - 0
src/petals/utils/__init__.py

@@ -3,5 +3,6 @@ from petals.utils.auto_config import (
     AutoDistributedModel,
     AutoDistributedModelForCausalLM,
     AutoDistributedModelForSequenceClassification,
+    AutoDistributedSpeculativeModel,
 )
 from petals.utils.dht import declare_active_modules, get_remote_module_infos

+ 5 - 0
src/petals/utils/auto_config.py

@@ -15,6 +15,7 @@ class _ModelClasses:
     config: Type[PretrainedConfig]
     model: Optional[Type[PreTrainedModel]] = None
     model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
+    model_for_speculative: Optional[Type[PreTrainedModel]] = None
     model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
 
 
@@ -90,5 +91,9 @@ class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase
     _mapping_field = "model_for_causal_lm"
 
 
+class AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistributedBase):
+    _mapping_field = "model_for_speculative"
+
+
 class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model_for_sequence_classification"

+ 1 - 1
src/petals/utils/convert_block.py

@@ -61,7 +61,7 @@ def convert_block(
     if adapters:
         from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
-        create_lora_adapter(block, quant_type=quant_type)
+        create_lora_adapter(block)
         for adapter_name in adapters:
             adapter_config, adapter_state_dict = load_peft(
                 adapter_name,

+ 43 - 48
src/petals/utils/peft.py

@@ -1,7 +1,7 @@
 import contextlib
 import re
 import time
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
 import bitsandbytes as bnb
 import torch
@@ -12,7 +12,7 @@ from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
 from peft.config import PeftConfig
 from peft.tuners import lora
-from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
+from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
 from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
@@ -25,6 +25,9 @@ from petals.utils.misc import get_size_in_bytes
 logger = get_logger(__name__)
 
 
+COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]
+
+
 def check_peft_repository(repo_id: str) -> bool:
     return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")
 
@@ -151,6 +154,18 @@ class AdapterContextMixin:
     def active_adapter(self, value: Optional[str]):
         assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""
 
+    @property
+    def active_adapters(self):
+        return [self._context_active_adapter]
+
+    def set_adapter(self, adapter_names) -> None:
+        """
+        In PEFT, this function makes the adapter trainable. However, in Petals environment this is not possible now. Thus,
+        this code removes this functionality.
+        Link to peft code: https://github.com/huggingface/peft/blob/98f4db2c7990ef9c879a0e1da9a28a19a04701ef/src/peft/tuners/tuners_utils.py#L463
+        """
+        pass
+
 
 using_adapter = AdapterContextMixin.using_adapter
 
@@ -158,60 +173,39 @@ using_adapter = AdapterContextMixin.using_adapter
 class LoraLinear(AdapterContextMixin, lora.Linear):
     """LoRA linear layer that uses adapter selected via using_adapter"""
 
+    def __init__(self, base_layer, adapter_name: str):
+        nn.Module.__init__(self)
+        lora.LoraLayer.__init__(self, base_layer)
+
+        self._active_adapter = adapter_name
+        self.is_target_conv_1d_layer = False
+
 
-class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
+class LoraLinear8bitLt(LoraLinear, lora.Linear8bitLt):
     """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
 
 
-class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
+class LoraLinear4bit(LoraLinear, lora.Linear4bit):
     """LoRA linear 4-bit that uses adapter selected via using_adapter"""
 
 
-def create_lora_adapter(block, quant_type: QuantType):
-    for _, module in block.named_modules():
+def create_lora_adapter(block):
+    for module_name, module in block.named_modules():
+        if isinstance(module, LoraLinear):
+            continue
         for child_name, child in module.named_children():
-            lora_wrapped_child = None
-            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
-                continue
-            if quant_type == QuantType.INT8:
-                kwargs = {
-                    "has_fp16_weights": False,
-                    "threshold": 6.0,
-                    "bias": hasattr(child, "bias") and child.bias is not None,
-                }
-                lora_wrapped_child = LoraLinear8bitLt(
-                    AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    **kwargs,
-                )
-            elif quant_type == QuantType.NF4:
-                kwargs = {
-                    "compress_statistics": True,
-                    "quant_type": "nf4",
-                    "blocksize": 64,
-                    "bias": hasattr(child, "bias") and child.bias is not None,
-                }
-                lora_wrapped_child = LoraLinear4bit(
-                    AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    **kwargs,
-                )
-                lora_wrapped_child.compute_dtype = child.compute_dtype
-            else:
-                bias = hasattr(child, "bias") and child.bias is not None
-                lora_wrapped_child = LoraLinear(
+            lora_class = None
+            if isinstance(child, nn.Linear):
+                lora_class = LoraLinear
+            elif isinstance(child, bnb.nn.Linear8bitLt):
+                lora_class = LoraLinear8bitLt
+            elif isinstance(child, bnb.nn.Linear4bit):
+                lora_class = LoraLinear4bit
+            if lora_class:
+                lora_wrapped_child = lora_class(
+                    child,
                     AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    bias=bias,
                 )
-            if lora_wrapped_child:
-                lora_wrapped_child.weight = child.weight
-                lora_wrapped_child.bias = child.bias
-                for p in lora_wrapped_child.parameters():
-                    p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)
 
 
@@ -240,6 +234,7 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                             adapter_name,
                             peft_config["r"],
                             peft_config["lora_alpha"],
+                            use_rslora=peft_config.get("use_rslora", False),
                             lora_dropout=peft_config["lora_dropout"],
                             init_lora_weights=peft_config["init_lora_weights"],
                         )
@@ -272,10 +267,10 @@ def estimate_adapter_memory_per_block(
     **load_peft_kwargs,
 ) -> int:
     """Get the number of extra bytes used to store a set of adapters per given block"""
-    with init_empty_weights(include_buffers=True):
+    with init_empty_weights(include_buffers=False):
         block = get_model_block(block_config)
         base_block_parameters = sum(p.numel() for p in block.parameters())
-        create_lora_adapter(block, quant_type=QuantType.NONE)
+        create_lora_adapter(block)
 
         for adapter in adapters:
             peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)

+ 50 - 2
tests/test_speculative_generation.py

@@ -2,8 +2,14 @@ import random
 
 import pytest
 import torch
+import transformers
 
-from petals import AutoDistributedConfig, RemoteSequential
+from petals import (
+    AutoDistributedConfig,
+    AutoDistributedSpeculativeModel,
+    DistributedLlamaForSpeculativeGeneration,
+    RemoteSequential,
+)
 from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
@@ -26,7 +32,6 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
     with torch.inference_mode():
         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
             initial_outputs_inference = sess.step(inputs)
-
             sess.position = 2
             secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
             result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
@@ -35,3 +40,46 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
     (outputs_local,) = ref_block(short_inputs)
 
     assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
+
+
+@pytest.fixture
+def noisy_model():
+    noisy_model = transformers.AutoModelForCausalLM.from_pretrained(
+        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    lm_head = noisy_model.get_output_embeddings()
+    assert isinstance(lm_head, torch.nn.Linear)
+    with torch.no_grad():
+        lm_head.weight += torch.randn_like(lm_head.weight) * 0.02
+    return noisy_model
+
+
+@pytest.fixture
+def model():
+    return transformers.AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+
+
+@pytest.fixture
+def tokenizer():
+    # We set use_fast=False since LlamaTokenizerFast is slow on load
+    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+
+
+@pytest.mark.forked
+@pytest.mark.skipif(
+    "llama" not in MODEL_NAME.lower(),
+    reason="Speculative generation now works only for llama models",
+)
+def test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3):
+    speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model
+    )
+
+    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+
+    generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False)
+    generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False)
+
+    assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference)