Quellcode durchsuchen

Support peft LoRA adapters (#335)

Implement an option to deploy PEFT adapters to a server. Clients can set active_adapter=... to use these adapters.

---------

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
Artem Chumachenko vor 2 Jahren
Ursprung
Commit
b9f0a5467f

+ 6 - 5
.github/workflows/run-tests.yaml

@@ -33,10 +33,11 @@ jobs:
         run: |
           export MODEL_NAME=bigscience/bloom-560m
           export REF_NAME=bigscience/bloom-560m
+          export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft
 
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
             --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
-            --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 &> server1.log &
+            --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server1.log &
           SERVER1_PID=$!
 
           sleep 5  # wait for the first server to initialize DHT
@@ -45,17 +46,17 @@ jobs:
           # ^-- server 1 multiaddr is determined by --identity and --host_maddrs
 
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server2.log &
+            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log &
           SERVER2_PID=$!
 
           sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
 
-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:5 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
+          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \
+            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log &
           SERVER3_PID=$!
 
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log &
+            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log &
           SERVER4_PID=$!
 
           tail -n 100 -f server*.log &

+ 2 - 0
setup.cfg

@@ -46,6 +46,8 @@ install_requires =
     cpufeature>=0.2.0
     packaging>=20.9
     sentencepiece>=0.1.99
+    peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
+    safetensors>=0.3.1
 
 [options.extras_require]
 dev =

+ 2 - 0
src/petals/cli/run_server.py

@@ -146,6 +146,8 @@ def main():
                         help="Skip checking this server's reachability via health.petals.ml "
                              "when connecting to the public swarm. If you connect to a private swarm, "
                              "the check is skipped by default. Use this option only if you know what you are doing")
+    
+    parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.")
 
     # fmt:on
     args = vars(parser.parse_args())

+ 2 - 1
src/petals/client/remote_sequential.py

@@ -28,6 +28,7 @@ class RemoteSequential(nn.Module):
         dht: Optional[DHT] = None,
         start_block: Optional[int] = None,
         end_block: Optional[int] = None,
+        **kwargs,
     ):
         super().__init__()
         self.config = config
@@ -41,7 +42,7 @@ class RemoteSequential(nn.Module):
             if end_block is None:
                 end_block = self.config.num_hidden_layers
             block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
-            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
+            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
         self.sequence_manager = sequence_manager
 
     def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):

+ 7 - 3
src/petals/client/routing/sequence_manager.py

@@ -43,6 +43,7 @@ class SequenceManagerConfig:
     min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
     max_backoff: float = 60  # limit maximal sleep time between retries to this value
     ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
+    active_adapter: Optional[str] = None
 
 
 @dataclasses.dataclass
@@ -78,6 +79,7 @@ class RemoteSequenceManager:
         *,
         dht: Optional[DHT] = None,
         state: Optional[SequenceManagerState] = None,
+        active_adapter: Optional[str] = None,
     ):
         assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
         assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@@ -115,7 +117,9 @@ class RemoteSequenceManager:
         if state.sequence_info.last_updated_time is None:
             # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
             # in the first _update() instead of the latest ones. This makes the first .update() faster.
-            petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True)
+            petals.dht_utils.get_remote_module_infos(
+                self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True
+            )
             self._need_latest_infos = False
         else:
             assert block_uids == state.sequence_info.block_uids
@@ -179,7 +183,7 @@ class RemoteSequenceManager:
     def _update(self):
         """Perform an immediate and synchronous refresh, may take time"""
         new_block_infos = petals.dht_utils.get_remote_module_infos(
-            self.dht, self.block_uids, latest=self._need_latest_infos
+            self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos
         )
         self._need_latest_infos = True  # All future _update() should use latest infos
 
@@ -307,7 +311,7 @@ class RemoteSequenceManager:
         :param kwargs: additional request context, such as remote peer ID
         :returns: msgpack-serialized metadata dict that will be passed alongside a given request
         """
-        return dict(points=self.policy.get_points(protocol, *args, **kwargs))
+        return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter)
 
     def shutdown(self):
         self._thread.shutdown()

+ 2 - 1
src/petals/data_structures.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import dataclasses
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 from hivemind import PeerID
 from hivemind.moe.expert_uid import ExpertUID
@@ -57,3 +57,4 @@ class InferenceMetadata:
     uid: ExpertUID
     prefix_length: int
     cache_handles: Tuple[Handle, ...]
+    active_adapter: Optional[str]

+ 26 - 4
src/petals/dht_utils.py

@@ -22,6 +22,7 @@ def declare_active_modules(
     expiration_time: DHTExpiration,
     state: ServerState,
     throughput: float,
+    adapters: Optional[Sequence[str]] = None,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
@@ -39,6 +40,7 @@ def declare_active_modules(
         uids = list(uids)
     for uid in uids:
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
+
     return dht.run_coroutine(
         partial(
             _declare_active_modules,
@@ -46,6 +48,7 @@ def declare_active_modules(
             expiration_time=expiration_time,
             state=state,
             throughput=throughput,
+            adapters=list(adapters or []),
         ),
         return_future=not wait,
     )
@@ -58,12 +61,13 @@ async def _declare_active_modules(
     expiration_time: DHTExpiration,
     state: ServerState,
     throughput: float,
+    adapters: List[str],
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[(state.value, throughput)] * len(uids),
+        values=[(state.value, throughput, dict(adapters=adapters))] * len(uids),
         expiration_time=expiration_time,
         num_workers=num_workers,
     )
@@ -73,18 +77,30 @@ def get_remote_module_infos(
     dht: DHT,
     uids: Sequence[ModuleUID],
     expiration_time: Optional[DHTExpiration] = None,
+    active_adapter: Optional[str] = None,
     *,
     latest: bool = False,
     return_future: bool = False,
 ) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
     return dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest),
+        partial(
+            _get_remote_module_infos,
+            uids=uids,
+            active_adapter=active_adapter,
+            expiration_time=expiration_time,
+            latest=latest,
+        ),
         return_future=return_future,
     )
 
 
 async def _get_remote_module_infos(
-    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool
+    dht: DHT,
+    node: DHTNode,
+    uids: List[ModuleUID],
+    active_adapter: Optional[str],
+    expiration_time: Optional[DHTExpiration],
+    latest: bool,
 ) -> List[Optional[RemoteModuleInfo]]:
     if latest:
         assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
@@ -105,7 +121,13 @@ async def _get_remote_module_infos(
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
-                state, throughput = server_info.value
+                state, throughput = server_info.value[:2]
+                extra_info = server_info.value[2] if len(server_info.value) > 2 else {}
+                adapters = extra_info.get("adapters", [])
+                if bool(active_adapter) and active_adapter not in adapters:
+                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
+                    continue
+
                 if not (
                     isinstance(state, int)
                     and isinstance(throughput, float)

+ 26 - 1
src/petals/server/backend.py

@@ -2,8 +2,9 @@ from __future__ import annotations
 
 from collections import Counter
 from itertools import chain
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
+import peft
 import torch
 from hivemind import BatchTensorDescriptor, TensorDescriptor
 from hivemind.moe.expert_uid import ExpertUID
@@ -80,6 +81,18 @@ class TransformerBackend(ModuleBackend):
             cache_tensors.extend((keys, values))
         return cache_tensors
 
+    def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
+        *inputs, active_adapter = inputs
+        if not self.load_adapter_(active_adapter):
+            raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
+        return super().forward(*inputs)
+
+    def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
+        *inputs, active_adapter = inputs
+        if not self.load_adapter_(active_adapter):
+            raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
+        return super().backward(*inputs)
+
     @torch.inference_mode()
     def inference_step(
         self,
@@ -88,6 +101,8 @@ class TransformerBackend(ModuleBackend):
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        if not self.load_adapter_(inference_info.active_adapter):
+            raise KeyError(f"Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
@@ -139,6 +154,16 @@ class TransformerBackend(ModuleBackend):
         for p in self.module.parameters():
             p.data = dummy
 
+    def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
+        """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
+        adapter_was_loaded = False
+        for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
+            if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
+                layer.active_adapter = active_adapter  # empty string for no adapter
+                if active_adapter in layer.lora_A.keys():
+                    adapter_was_loaded = True
+        return adapter_was_loaded or not active_adapter
+
 
 def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
     """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""

+ 31 - 7
src/petals/server/handler.py

@@ -141,6 +141,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
                 requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                 max_length = metadata.get("max_length")
+                active_adapter = metadata.get("active_adapter", "")
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
 
@@ -201,7 +202,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         )
 
                         inference_infos = tuple(
-                            InferenceMetadata(uid, prefix_length, tuple(handles))
+                            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
                             for uid, handles in zip(requested_uids, cache_handles)
                         )
 
@@ -354,13 +355,18 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_forward should have number of points as number or None, got {points}"
 
             hidden_states = await _rpc_forward(
-                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_inputs,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
             return runtime_pb2.ExpertResponse(
                 tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@@ -376,13 +382,18 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_forward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
             hidden_states = await _rpc_forward(
-                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_inputs,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
 
             # Split the serialized_output for streaming and respond to client
@@ -422,13 +433,18 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_backward should have number of points as number or None, got {points}"
 
             grads = await _rpc_backward(
-                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_tensors,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
 
             return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@@ -442,13 +458,18 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_backward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
             grads = await _rpc_backward(
-                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_tensors,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
             # Split the serialized_grad_inputs for streaming and respond
             for tensor in self._serialize_grads(grads, requested_backends, metadata):
@@ -553,6 +574,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 async def _rpc_forward(
     *flat_tensors: torch.Tensor,
     requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
 ) -> torch.Tensor:
@@ -585,6 +607,7 @@ async def _rpc_forward(
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
             hidden_states,
+            active_adapter,
             priority=priority,
         )
         assert isinstance(hidden_states, torch.Tensor)
@@ -598,6 +621,7 @@ async def _rpc_forward(
 async def _rpc_backward(
     *flat_tensors: torch.Tensor,
     requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
@@ -623,7 +647,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
             inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
         )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
+        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
 
         assert isinstance(inputs, torch.Tensor)
 
@@ -639,7 +663,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
 
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):

+ 26 - 1
src/petals/server/server.py

@@ -81,6 +81,7 @@ class Server:
         dht_client_mode: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
+        adapters: Optional[List[str]] = None,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -218,6 +219,8 @@ class Server:
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
 
+        self.adapters = adapters
+
         self.stop = threading.Event()
 
     def _choose_num_blocks(self) -> int:
@@ -291,6 +294,7 @@ class Server:
                 quant_type=self.quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 should_validate_reachability=self.should_validate_reachability,
+                adapters=self.adapters,
                 start=True,
             )
             try:
@@ -384,6 +388,7 @@ class ModuleContainer(threading.Thread):
         quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
+        adapters: Optional[List[str]] = None,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
@@ -391,6 +396,7 @@ class ModuleContainer(threading.Thread):
             module_uids,
             dht,
             ServerState.JOINING,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -415,7 +421,19 @@ class ModuleContainer(threading.Thread):
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
-                block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True)
+                block = convert_block(
+                    block,
+                    block_index,
+                    block_config,
+                    tensor_parallel_devices,
+                    device,
+                    quant_type,
+                    adapters=adapters,
+                    freeze=True,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
+                )
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
@@ -452,6 +470,7 @@ class ModuleContainer(threading.Thread):
                 expiration_time=get_dht_time() + expiration,
                 state=ServerState.OFFLINE,
                 throughput=throughput,
+                adapters=adapters,
             )
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
@@ -465,6 +484,7 @@ class ModuleContainer(threading.Thread):
             dht,
             dht_prefix,
             blocks,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -480,6 +500,7 @@ class ModuleContainer(threading.Thread):
         inference_max_length: int,
         num_handlers: int,
         throughput: float,
+        adapters: Optional[Sequence[str]],
         update_period: float,
         expiration: Optional[float] = None,
         request_timeout: float,
@@ -517,6 +538,7 @@ class ModuleContainer(threading.Thread):
             list(self.module_backends.keys()),
             dht,
             ServerState.ONLINE,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -616,6 +638,7 @@ class ModuleAnnouncerThread(threading.Thread):
         module_uids: List[str],
         dht: DHT,
         state: ServerState,
+        adapters: Optional[Sequence[str]],
         *,
         throughput: float,
         update_period: float = 30,
@@ -626,6 +649,7 @@ class ModuleAnnouncerThread(threading.Thread):
         self.module_uids = module_uids
         self.dht = dht
         self.state = state
+        self.adapters = adapters
         self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
@@ -639,6 +663,7 @@ class ModuleAnnouncerThread(threading.Thread):
                 expiration_time=get_dht_time() + self.expiration,
                 state=self.state,
                 throughput=self.throughput,
+                adapters=self.adapters,
             )
             if self.stop.wait(self.update_period):
                 break

+ 1 - 1
src/petals/server/throughput.py

@@ -172,7 +172,7 @@ def measure_compute_rps(
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
         block = config.block_class(config).to(dtype)
-        block = convert_block(block, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+        block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
         cache = None
         elapsed = 0

+ 17 - 8
src/petals/utils/convert_block.py

@@ -3,8 +3,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor par
 """
 import os
 import re
-from enum import Enum
-from typing import Sequence
+from typing import List, Optional, Sequence
 
 import tensor_parallel as tp
 import torch
@@ -13,23 +12,23 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 
+from petals.utils.misc import QuantType
+from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
-class QuantType(Enum):
-    NONE = 0
-    INT8 = 1  # 8-bit as in the LLM.int8() paper
-    NF4 = 2  # 4-bit as in the QLoRA paper
-
-
 def convert_block(
     block: nn.Module,
+    block_index: int,
     config: PretrainedConfig,
     tensor_parallel_devices: Sequence[torch.device],
     output_device: torch.device,
     quant_type: QuantType,
     freeze: bool = True,
+    adapters: Optional[List[str]] = None,
+    **kwargs,
 ) -> tp.TensorParallel:
     """
     Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
@@ -56,6 +55,16 @@ def convert_block(
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
 
+    if adapters:
+        create_lora_adapter(block, quant_type=quant_type)
+        for adapter_name in adapters:
+            adapter_config, adapter_state_dict = load_peft(
+                adapter_name,
+                block_idx=block_index,
+                **kwargs,
+            )
+            add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
+
     return block
 
 

+ 9 - 0
src/petals/utils/misc.py

@@ -1,5 +1,14 @@
+from enum import Enum
+
 import torch
 
+
+class QuantType(Enum):
+    NONE = 0
+    INT8 = 1  # 8-bit as in the LLM.int8() paper
+    NF4 = 2  # 4-bit as in the QLoRA paper
+
+
 DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
 
 

+ 208 - 0
src/petals/utils/peft.py

@@ -0,0 +1,208 @@
+import re
+import time
+from typing import List, Optional
+
+import bitsandbytes as bnb
+import torch.nn as nn
+from hivemind.utils.logging import get_logger
+from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
+from peft.tuners import lora
+from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
+from safetensors import safe_open
+from safetensors.torch import load_file
+from transformers.utils import get_file_from_repo
+
+from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.misc import QuantType
+
+logger = get_logger(__name__)
+
+
+def check_peft_repository(repo_id: str) -> bool:
+    fs = HfFileSystem()
+    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
+    return len(list_of_files) > 0
+
+
+def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
+    tensors = dict()
+    is_tensors_found = dict()
+    common_layer_patter_re = (
+        ".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + f"\.({block_idx})?\..+"
+    )
+    with safe_open(filepath, framework=framework, device=device) as f:
+        for k in f.keys():
+            if re.match(common_layer_patter_re, k):
+                is_tensors_found[block_idx] = True
+                tensors[k] = f.get_tensor(k)
+        if not is_tensors_found.get(block_idx, False):
+            logger.warning(f"There is no peft weights for block {block_idx}")
+        return tensors
+
+
+def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
+    config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
+    if config_path is None:
+        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
+    config = PeftConfig.from_json_file(config_path)
+
+    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
+    if weight_path is None:
+        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
+    if block_idx is None:
+        return config, load_file(weight_path)
+    return config, load_specific_module(block_idx, weight_path, device=device)
+
+
+def load_peft(
+    repo_id: str,
+    block_idx: Optional[int] = None,
+    device: Optional[int] = None,
+    *,
+    revision: Optional[str] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: str,
+    max_disk_space: Optional[int] = None,
+    delay: float = 30,
+):
+    # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here
+
+    if not check_peft_repository(repo_id):
+        raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")
+
+    try:
+        with allow_cache_reads(cache_dir):
+            return get_adapter_from_repo(
+                repo_id,
+                block_idx,
+                device,
+                revision=revision,
+                use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
+                local_files_only=False,
+            )
+    except Exception:
+        logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
+
+    while True:
+        try:
+            with allow_cache_writes(cache_dir):
+                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
+                config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
+                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+                weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size
+
+                file_size = config_file_size + weight_file_size
+                if file_size is not None:
+                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                else:
+                    logger.warning(f"Failed to fetch size from peft repo {repo_id}")
+
+                return get_adapter_from_repo(
+                    repo_id,
+                    block_idx,
+                    device,
+                    revision=revision,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    local_files_only=False,
+                )
+        except Exception as e:
+            logger.warning(
+                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
+            )
+            time.sleep(delay)
+
+
+def create_lora_adapter(block, quant_type: QuantType):
+    for name, module in block.named_modules():
+        for child_name, child in module.named_children():
+            lora_wrapped_child = None
+            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
+                continue
+            if quant_type == QuantType.INT8:
+                kwargs = {
+                    "has_fp16_weights": False,
+                    "threshold": 6.0,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
+                }
+                lora_wrapped_child = lora.Linear8bitLt(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    **kwargs,
+                )
+            elif quant_type == QuantType.NF4:
+                kwargs = {
+                    "compress_statistics": True,
+                    "quant_type": "nf4",
+                    "blocksize": 64,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
+                }
+                lora_wrapped_child = lora.Linear4bit(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    **kwargs,
+                )
+            else:
+                bias = hasattr(child, "bias") and child.bias is not None
+                lora_wrapped_child = lora.Linear(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    bias=bias,
+                )
+            if lora_wrapped_child:
+                lora_wrapped_child.active_adapter = None
+                lora_wrapped_child.weight = child.weight
+                lora_wrapped_child.bias = child.bias
+                for p in lora_wrapped_child.parameters():
+                    p.requires_grad = False
+                setattr(module, child_name, lora_wrapped_child)
+
+
+def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
+    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
+    for name, module in block.named_modules():
+        for child_name, child in module.named_children():
+            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
+                continue
+
+            if child_name in peft_config["target_modules"] or (
+                isinstance(peft_config["target_modules"], str)
+                and re.fullmatch(peft_config["target_modules"], child_name)
+            ):
+                is_lora_a_loaded = False
+                is_lora_b_loaded = False
+                for peft_key in peft_state_dict:
+                    if peft_key.find(child_name) == -1:
+                        continue
+
+                    if adapter_name not in child.lora_A:
+                        child.update_layer(
+                            adapter_name,
+                            peft_config["r"],
+                            peft_config["lora_alpha"],
+                            lora_dropout=peft_config["lora_dropout"],
+                            init_lora_weights=peft_config["init_lora_weights"],
+                        )
+                        child.train(False)
+                        if peft_config["lora_dropout"] > 0:
+                            logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout")
+                        for p in child.parameters():
+                            p.requires_grad = False
+
+                    if peft_key.endswith(".lora_A.weight"):
+                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
+                        is_lora_a_loaded = True
+                    elif peft_key.endswith(".lora_A.bias"):
+                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
+                    elif peft_key.endswith(".lora_B.weight"):
+                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
+                        is_lora_b_loaded = True
+                    elif peft_key.endswith(".lora_B.bias"):
+                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
+
+                if is_lora_a_loaded and is_lora_b_loaded:
+                    logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")

+ 11 - 2
tests/test_full_model.py

@@ -1,3 +1,4 @@
+import peft
 import pytest
 import torch
 import transformers
@@ -12,11 +13,16 @@ logger = get_logger(__name__)
 
 
 @pytest.mark.forked
+@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
+def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+        MODEL_NAME,
+        initial_peers=INITIAL_PEERS,
+        low_cpu_mem_usage=True,
+        torch_dtype=torch.float32,
+        active_adapter=ADAPTER_NAME if use_peft else None,
     )
     config = model.config
     assert isinstance(model, DistributedBloomForCausalLM)
@@ -54,6 +60,9 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
             ref_model = transformers.BloomForCausalLM.from_pretrained(
                 REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
             )
+            if use_peft:
+                ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
+                ref_model.train(False)
             if config.vocab_size < ref_model.config.vocab_size:
                 ref_model.resize_token_embeddings(config.vocab_size)
                 logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

+ 66 - 0
tests/test_peft.py

@@ -0,0 +1,66 @@
+import os
+import shutil
+
+import pytest
+from huggingface_hub import snapshot_download
+
+from petals.utils.peft import check_peft_repository, load_peft
+
+UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
+SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
+TMP_CACHE_DIR = "tmp_cache/"
+
+
+def clear_dir(path_to_dir):
+    shutil.rmtree(path_to_dir)
+    os.mkdir(path_to_dir)
+
+
+def dir_empty(path_to_dir):
+    files = os.listdir(path_to_dir)
+    return len(files) == 0
+
+
+@pytest.mark.forked
+def test_check_peft():
+    assert not check_peft_repository(UNSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load."
+    assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
+
+
+@pytest.mark.forked
+def test_load_noncached(tmpdir):
+    clear_dir(tmpdir)
+    with pytest.raises(Exception):
+        load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)
+
+    assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"
+
+    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
+
+    assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"
+
+
+@pytest.mark.forked
+def test_load_cached(tmpdir):
+    clear_dir(tmpdir)
+    snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)
+
+    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
+
+
+@pytest.mark.forked
+def test_load_layer_exists(tmpdir):
+    clear_dir(tmpdir)
+
+    load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)
+
+
+@pytest.mark.forked
+def test_load_layer_nonexists(tmpdir):
+    clear_dir(tmpdir)
+
+    load_peft(
+        SAFE_PEFT_REPO,
+        block_idx=1337,
+        cache_dir=tmpdir,
+    )

+ 2 - 0
tests/test_utils.py

@@ -11,3 +11,5 @@ if not MODEL_NAME:
     raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
 
 REF_NAME = os.environ.get("REF_NAME")
+
+ADAPTER_NAME = os.environ.get("ADAPTER_NAME")