Support peft LoRA adapters (#335)

Implement an option for servers to load PEFT LoRA adapters (via the new --adapters argument). Clients can set active_adapter=... to run inference or training with these adapters.
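
For illustration, a minimal client-side sketch (not part of this commit; the model and adapter names are placeholders, and the servers are assumed to have been started with --adapters, as in the updated CI workflow below):

    import torch
    from petals import DistributedBloomForCausalLM  # usual Petals client entry point

    # Placeholder names: any model served by the swarm, plus a safetensors LoRA
    # adapter that the servers listed in --adapters, would work the same way.
    model = DistributedBloomForCausalLM.from_pretrained(
        "bigscience/bloom-560m",
        active_adapter="artek0chumak/bloom-560m-safe-peft",
        torch_dtype=torch.float32,
    )
    # Requests are routed only to servers that announced this adapter in the DHT;
    # those servers apply the corresponding LoRA weights to their blocks.

Servers opt in by passing --adapters <repo_id> [...] to petals.cli.run_server, as shown in the workflow changes below.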

---------

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
Artem Chumachenko committed 2 years ago · commit b9f0a5467f

+ 6 - 5
.github/workflows/run-tests.yaml

@@ -33,10 +33,11 @@ jobs:
        run: |
          export MODEL_NAME=bigscience/bloom-560m
          export REF_NAME=bigscience/bloom-560m
+          export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft

          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
            --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
-            --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 &> server1.log &
+            --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server1.log &
          SERVER1_PID=$!

          sleep 5  # wait for the first server to initialize DHT
@@ -45,17 +46,17 @@
          # ^-- server 1 multiaddr is determined by --identity and --host_maddrs

          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server2.log &
+            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log &
          SERVER2_PID=$!

          sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:5 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
+          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \
+            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log &
          SERVER3_PID=$!

          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log &
+            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log &
          SERVER4_PID=$!

          tail -n 100 -f server*.log &

+ 2 - 0
setup.cfg

@@ -46,6 +46,8 @@ install_requires =
    cpufeature>=0.2.0
    packaging>=20.9
    sentencepiece>=0.1.99
+    peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
+    safetensors>=0.3.1

[options.extras_require]
dev =

+ 2 - 0
src/petals/cli/run_server.py

@@ -146,6 +146,8 @@ def main():
                        help="Skip checking this server's reachability via health.petals.ml "
                             "when connecting to the public swarm. If you connect to a private swarm, "
                             "the check is skipped by default. Use this option only if you know what you are doing")
+
+    parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.")

    # fmt:on
    args = vars(parser.parse_args())

+ 2 - 1
src/petals/client/remote_sequential.py

@@ -28,6 +28,7 @@ class RemoteSequential(nn.Module):
        dht: Optional[DHT] = None,
        start_block: Optional[int] = None,
        end_block: Optional[int] = None,
+        **kwargs,
    ):
        super().__init__()
        self.config = config
@@ -41,7 +42,7 @@ class RemoteSequential(nn.Module):
            if end_block is None:
                end_block = self.config.num_hidden_layers
            block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
-            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
+            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
        self.sequence_manager = sequence_manager

    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):

+ 7 - 3
src/petals/client/routing/sequence_manager.py

@@ -43,6 +43,7 @@ class SequenceManagerConfig:
    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
    max_backoff: float = 60  # limit maximal sleep time between retries to this value
    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
+    active_adapter: Optional[str] = None


@dataclasses.dataclass
@@ -78,6 +79,7 @@ class RemoteSequenceManager:
        *,
        dht: Optional[DHT] = None,
        state: Optional[SequenceManagerState] = None,
+        active_adapter: Optional[str] = None,
    ):
        assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@@ -115,7 +117,9 @@ class RemoteSequenceManager:
        if state.sequence_info.last_updated_time is None:
            # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
            # in the first _update() instead of the latest ones. This makes the first .update() faster.
-            petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True)
+            petals.dht_utils.get_remote_module_infos(
+                self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True
+            )
            self._need_latest_infos = False
        else:
            assert block_uids == state.sequence_info.block_uids
@@ -179,7 +183,7 @@ class RemoteSequenceManager:
    def _update(self):
        """Perform an immediate and synchronous refresh, may take time"""
        new_block_infos = petals.dht_utils.get_remote_module_infos(
-            self.dht, self.block_uids, latest=self._need_latest_infos
+            self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos
        )
        self._need_latest_infos = True  # All future _update() should use latest infos

@@ -307,7 +311,7 @@ class RemoteSequenceManager:
        :param kwargs: additional request context, such as remote peer ID
        :returns: msgpack-serialized metadata dict that will be passed alongside a given request
        """
-        return dict(points=self.policy.get_points(protocol, *args, **kwargs))
+        return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter)

    def shutdown(self):
        self._thread.shutdown()

+ 2 - 1
src/petals/data_structures.py

@@ -3,7 +3,7 @@ from __future__ import annotations
import dataclasses
from dataclasses import dataclass
from enum import Enum
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple

from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
@@ -57,3 +57,4 @@ class InferenceMetadata:
    uid: ExpertUID
    prefix_length: int
    cache_handles: Tuple[Handle, ...]
+    active_adapter: Optional[str]

+ 26 - 4
src/petals/dht_utils.py

@@ -22,6 +22,7 @@ def declare_active_modules(
    expiration_time: DHTExpiration,
    state: ServerState,
    throughput: float,
+    adapters: Optional[Sequence[str]] = None,
    wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
    """
@@ -39,6 +40,7 @@ def declare_active_modules(
        uids = list(uids)
    for uid in uids:
        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
+
    return dht.run_coroutine(
        partial(
            _declare_active_modules,
@@ -46,6 +48,7 @@ def declare_active_modules(
            expiration_time=expiration_time,
            state=state,
            throughput=throughput,
+            adapters=list(adapters or []),
        ),
        return_future=not wait,
    )
@@ -58,12 +61,13 @@ async def _declare_active_modules(
    expiration_time: DHTExpiration,
    state: ServerState,
    throughput: float,
+    adapters: List[str],
) -> Dict[ModuleUID, bool]:
    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
    return await node.store_many(
        keys=uids,
        subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[(state.value, throughput)] * len(uids),
+        values=[(state.value, throughput, dict(adapters=adapters))] * len(uids),
        expiration_time=expiration_time,
        num_workers=num_workers,
    )
@@ -73,18 +77,30 @@ def get_remote_module_infos(
    dht: DHT,
    uids: Sequence[ModuleUID],
    expiration_time: Optional[DHTExpiration] = None,
+    active_adapter: Optional[str] = None,
    *,
    latest: bool = False,
    return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
    return dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest),
+        partial(
+            _get_remote_module_infos,
+            uids=uids,
+            active_adapter=active_adapter,
+            expiration_time=expiration_time,
+            latest=latest,
+        ),
        return_future=return_future,
    )


async def _get_remote_module_infos(
-    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool
+    dht: DHT,
+    node: DHTNode,
+    uids: List[ModuleUID],
+    active_adapter: Optional[str],
+    expiration_time: Optional[DHTExpiration],
+    latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
    if latest:
        assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
@@ -105,7 +121,13 @@ async def _get_remote_module_infos(
        for peer_id, server_info in metadata.value.items():
            try:
                peer_id = PeerID.from_base58(peer_id)
-                state, throughput = server_info.value
+                state, throughput = server_info.value[:2]
+                extra_info = server_info.value[2] if len(server_info.value) > 2 else {}
+                adapters = extra_info.get("adapters", [])
+                if bool(active_adapter) and active_adapter not in adapters:
+                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
+                    continue
+
                if not (
                    isinstance(state, int)
                    and isinstance(throughput, float)

+ 26 - 1
src/petals/server/backend.py

@@ -2,8 +2,9 @@ from __future__ import annotations

from collections import Counter
from itertools import chain
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple, Union

+import peft
import torch
from hivemind import BatchTensorDescriptor, TensorDescriptor
from hivemind.moe.expert_uid import ExpertUID
@@ -80,6 +81,18 @@ class TransformerBackend(ModuleBackend):
            cache_tensors.extend((keys, values))
        return cache_tensors

+    def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
+        *inputs, active_adapter = inputs
+        if not self.load_adapter_(active_adapter):
+            raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
+        return super().forward(*inputs)
+
+    def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
+        *inputs, active_adapter = inputs
+        if not self.load_adapter_(active_adapter):
+            raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
+        return super().backward(*inputs)
+
    @torch.inference_mode()
    def inference_step(
        self,
@@ -88,6 +101,8 @@ class TransformerBackend(ModuleBackend):
        inference_info: InferenceMetadata,
    ) -> Tuple[torch.Tensor, ...]:
        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        if not self.load_adapter_(inference_info.active_adapter):
+            raise KeyError(f"Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
            self._reorder_cache_inplace(cache_tensors, hypo_ids)
            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
@@ -139,6 +154,16 @@ class TransformerBackend(ModuleBackend):
        for p in self.module.parameters():
            p.data = dummy

+    def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
+        """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
+        adapter_was_loaded = False
+        for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
+            if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
+                layer.active_adapter = active_adapter  # empty string for no adapter
+                if active_adapter in layer.lora_A.keys():
+                    adapter_was_loaded = True
+        return adapter_was_loaded or not active_adapter
+

def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
    """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""

+ 31 - 7
src/petals/server/handler.py

@@ -141,6 +141,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
                requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                max_length = metadata.get("max_length")
+                active_adapter = metadata.get("active_adapter", "")
                points = metadata.get("points", 0)
                session_id = metadata.get("session_id")

@@ -201,7 +202,7 @@
                        )

                        inference_infos = tuple(
-                            InferenceMetadata(uid, prefix_length, tuple(handles))
+                            InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
                            for uid, handles in zip(requested_uids, cache_handles)
                        )

@@ -354,13 +355,18 @@

            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            active_adapter = metadata.get("active_adapter", "")
            points = metadata.get("points", 0)
            assert isinstance(
                points, (float, int)
            ), f"rpc_forward should have number of points as number or None, got {points}"

            hidden_states = await _rpc_forward(
-                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_inputs,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
            )
            return runtime_pb2.ExpertResponse(
                tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@@ -376,13 +382,18 @@
            self._log_request("rpc_forward_stream", requested_uids, context)

            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            active_adapter = metadata.get("active_adapter", "")
            points = metadata.get("points", 0)
            assert isinstance(
                points, (float, int)
            ), f"rpc_forward_stream should have number of points as number or None, got {points}"

            hidden_states = await _rpc_forward(
-                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_inputs,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
            )

            # Split the serialized_output for streaming and respond to client
@@ -422,13 +433,18 @@

            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            active_adapter = metadata.get("active_adapter", "")
            points = metadata.get("points", 0)
            assert isinstance(
                points, (float, int)
            ), f"rpc_backward should have number of points as number or None, got {points}"

            grads = await _rpc_backward(
-                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_tensors,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
            )

            return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@@ -442,13 +458,18 @@
            self._log_request("rpc_backward_stream", requested_uids, context)

            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            active_adapter = metadata.get("active_adapter", "")
            points = metadata.get("points", 0)
            assert isinstance(
                points, (float, int)
            ), f"rpc_backward_stream should have number of points as number or None, got {points}"

            grads = await _rpc_backward(
-                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_tensors,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
            )
            # Split the serialized_grad_inputs for streaming and respond
            for tensor in self._serialize_grads(grads, requested_backends, metadata):
@@ -553,6 +574,7 @@ class TransformerConnectionHandler(ConnectionHandler):
async def _rpc_forward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> torch.Tensor:
@@ -585,6 +607,7 @@ async def _rpc_forward(
        )
        (hidden_states,) = await backend.forward_pool.submit_task(
            hidden_states,
+            active_adapter,
            priority=priority,
        )
        assert isinstance(hidden_states, torch.Tensor)
@@ -598,6 +621,7 @@ async def _rpc_forward(
async def _rpc_backward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
@@ -623,7 +647,7 @@ async def _rpc_backward(
        priority = prioritizer.prioritize(
            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
        )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
+        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)

        assert isinstance(inputs, torch.Tensor)

@@ -639,7 +663,7 @@ async def _rpc_backward(
        priority = prioritizer.prioritize(
            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
        )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)

        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):

+ 26 - 1
src/petals/server/server.py

@@ -81,6 +81,7 @@ class Server:
        dht_client_mode: Optional[bool] = None,
        use_relay: bool = True,
        use_auto_relay: bool = True,
+        adapters: Optional[List[str]] = None,
        **kwargs,
    ):
        """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -218,6 +219,8 @@
        self.mean_balance_check_period = mean_balance_check_period
        self.mean_block_selection_delay = mean_block_selection_delay

+        self.adapters = adapters
+
        self.stop = threading.Event()

    def _choose_num_blocks(self) -> int:
@@ -291,6 +294,7 @@
                quant_type=self.quant_type,
                tensor_parallel_devices=self.tensor_parallel_devices,
                should_validate_reachability=self.should_validate_reachability,
+                adapters=self.adapters,
                start=True,
            )
            try:
@@ -384,6 +388,7 @@ class ModuleContainer(threading.Thread):
        quant_type: QuantType,
        tensor_parallel_devices: Sequence[torch.device],
        should_validate_reachability: bool,
+        adapters: Optional[List[str]] = None,
        **kwargs,
    ) -> ModuleContainer:
        module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
@@ -391,6 +396,7 @@
            module_uids,
            dht,
            ServerState.JOINING,
+            adapters=adapters,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
@@ -415,7 +421,19 @@
                    cache_dir=cache_dir,
                    max_disk_space=max_disk_space,
                )
-                block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True)
+                block = convert_block(
+                    block,
+                    block_index,
+                    block_config,
+                    tensor_parallel_devices,
+                    device,
+                    quant_type,
+                    adapters=adapters,
+                    freeze=True,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
+                )
                blocks[module_uid] = TransformerBackend(
                    module_uid,
                    block,
@@ -452,6 +470,7 @@
                expiration_time=get_dht_time() + expiration,
                state=ServerState.OFFLINE,
                throughput=throughput,
+                adapters=adapters,
            )
            logger.info(f"Announced that blocks {module_uids} are offline")
            raise
@@ -465,6 +484,7 @@
            dht,
            dht_prefix,
            blocks,
+            adapters=adapters,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
@@ -480,6 +500,7 @@
        inference_max_length: int,
        num_handlers: int,
        throughput: float,
+        adapters: Optional[Sequence[str]],
        update_period: float,
        expiration: Optional[float] = None,
        request_timeout: float,
@@ -517,6 +538,7 @@
            list(self.module_backends.keys()),
            dht,
            ServerState.ONLINE,
+            adapters=adapters,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
@@ -616,6 +638,7 @@ class ModuleAnnouncerThread(threading.Thread):
        module_uids: List[str],
        dht: DHT,
        state: ServerState,
+        adapters: Optional[Sequence[str]],
        *,
        throughput: float,
        update_period: float = 30,
@@ -626,6 +649,7 @@
        self.module_uids = module_uids
        self.dht = dht
        self.state = state
+        self.adapters = adapters
        self.throughput = throughput
        self.update_period = update_period
        self.expiration = expiration
@@ -639,6 +663,7 @@
                expiration_time=get_dht_time() + self.expiration,
                state=self.state,
                throughput=self.throughput,
+                adapters=self.adapters,
            )
            if self.stop.wait(self.update_period):
                break

+ 1 - 1
src/petals/server/throughput.py

@@ -172,7 +172,7 @@ def measure_compute_rps(
        tensor_parallel_devices = (device,)
    with torch.inference_mode():
        block = config.block_class(config).to(dtype)
-        block = convert_block(block, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+        block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

        cache = None
        elapsed = 0

+ 17 - 8
src/petals/utils/convert_block.py

@@ -3,8 +3,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor par
"""
import os
import re
-from enum import Enum
-from typing import Sequence
+from typing import List, Optional, Sequence

import tensor_parallel as tp
import torch
@@ -13,23 +12,23 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tensor_parallel.slicing_configs import get_bloom_config
from transformers import PretrainedConfig

+from petals.utils.misc import QuantType
+from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


-class QuantType(Enum):
-    NONE = 0
-    INT8 = 1  # 8-bit as in the LLM.int8() paper
-    NF4 = 2  # 4-bit as in the QLoRA paper
-
-
def convert_block(
    block: nn.Module,
+    block_index: int,
    config: PretrainedConfig,
    tensor_parallel_devices: Sequence[torch.device],
    output_device: torch.device,
    quant_type: QuantType,
    freeze: bool = True,
+    adapters: Optional[List[str]] = None,
+    **kwargs,
) -> tp.TensorParallel:
    """
    Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
@@ -56,6 +55,16 @@ def convert_block(
    for shard, device in zip(block.module_shards, block.devices):
        shard.to(device)

+    if adapters:
+        create_lora_adapter(block, quant_type=quant_type)
+        for adapter_name in adapters:
+            adapter_config, adapter_state_dict = load_peft(
+                adapter_name,
+                block_idx=block_index,
+                **kwargs,
+            )
+            add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
+
    return block

+ 9 - 0
src/petals/utils/misc.py

@@ -1,5 +1,14 @@
+from enum import Enum
+
import torch

+
+class QuantType(Enum):
+    NONE = 0
+    INT8 = 1  # 8-bit as in the LLM.int8() paper
+    NF4 = 2  # 4-bit as in the QLoRA paper
+
+
DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters

+ 208 - 0
src/petals/utils/peft.py

@@ -0,0 +1,208 @@
+import re
+import time
+from typing import List, Optional
+
+import bitsandbytes as bnb
+import torch.nn as nn
+from hivemind.utils.logging import get_logger
+from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
+from peft.tuners import lora
+from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
+from safetensors import safe_open
+from safetensors.torch import load_file
+from transformers.utils import get_file_from_repo
+
+from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.misc import QuantType
+
+logger = get_logger(__name__)
+
+
+def check_peft_repository(repo_id: str) -> bool:
+    fs = HfFileSystem()
+    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
+    return len(list_of_files) > 0
+
+
+def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
+    tensors = dict()
+    is_tensors_found = dict()
+    common_layer_patter_re = (
+        ".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + f"\.({block_idx})?\..+"
+    )
+    with safe_open(filepath, framework=framework, device=device) as f:
+        for k in f.keys():
+            if re.match(common_layer_patter_re, k):
+                is_tensors_found[block_idx] = True
+                tensors[k] = f.get_tensor(k)
+        if not is_tensors_found.get(block_idx, False):
+            logger.warning(f"There is no peft weights for block {block_idx}")
+        return tensors
+
+
+def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
+    config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
+    if config_path is None:
+        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
+    config = PeftConfig.from_json_file(config_path)
+
+    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
+    if weight_path is None:
+        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
+    if block_idx is None:
+        return config, load_file(weight_path)
+    return config, load_specific_module(block_idx, weight_path, device=device)
+
+
+def load_peft(
+    repo_id: str,
+    block_idx: Optional[int] = None,
+    device: Optional[int] = None,
+    *,
+    revision: Optional[str] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: str,
+    max_disk_space: Optional[int] = None,
+    delay: float = 30,
+):
+    # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here
+
+    if not check_peft_repository(repo_id):
+        raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")
+
+    try:
+        with allow_cache_reads(cache_dir):
+            return get_adapter_from_repo(
+                repo_id,
+                block_idx,
+                device,
+                revision=revision,
+                use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
+                local_files_only=False,
+            )
+    except Exception:
+        logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
+
+    while True:
+        try:
+            with allow_cache_writes(cache_dir):
+                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
+                config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
+                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+                weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size
+
+                file_size = config_file_size + weight_file_size
+                if file_size is not None:
+                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                else:
+                    logger.warning(f"Failed to fetch size from peft repo {repo_id}")
+
+                return get_adapter_from_repo(
+                    repo_id,
+                    block_idx,
+                    device,
+                    revision=revision,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    local_files_only=False,
+                )
+        except Exception as e:
+            logger.warning(
+                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
+            )
+            time.sleep(delay)
+
+
+def create_lora_adapter(block, quant_type: QuantType):
+    for name, module in block.named_modules():
+        for child_name, child in module.named_children():
+            lora_wrapped_child = None
+            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
+                continue
+            if quant_type == QuantType.INT8:
+                kwargs = {
+                    "has_fp16_weights": False,
+                    "threshold": 6.0,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
+                }
+                lora_wrapped_child = lora.Linear8bitLt(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    **kwargs,
+                )
+            elif quant_type == QuantType.NF4:
+                kwargs = {
+                    "compress_statistics": True,
+                    "quant_type": "nf4",
+                    "blocksize": 64,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
+                }
+                lora_wrapped_child = lora.Linear4bit(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    **kwargs,
+                )
+            else:
+                bias = hasattr(child, "bias") and child.bias is not None
+                lora_wrapped_child = lora.Linear(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    bias=bias,
+                )
+            if lora_wrapped_child:
+                lora_wrapped_child.active_adapter = None
+                lora_wrapped_child.weight = child.weight
+                lora_wrapped_child.bias = child.bias
+                for p in lora_wrapped_child.parameters():
+                    p.requires_grad = False
+                setattr(module, child_name, lora_wrapped_child)
+
+
+def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
+    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
+    for name, module in block.named_modules():
+        for child_name, child in module.named_children():
+            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
+                continue
+
+            if child_name in peft_config["target_modules"] or (
+                isinstance(peft_config["target_modules"], str)
+                and re.fullmatch(peft_config["target_modules"], child_name)
+            ):
+                is_lora_a_loaded = False
+                is_lora_b_loaded = False
+                for peft_key in peft_state_dict:
+                    if peft_key.find(child_name) == -1:
+                        continue
+
+                    if adapter_name not in child.lora_A:
+                        child.update_layer(
+                            adapter_name,
+                            peft_config["r"],
+                            peft_config["lora_alpha"],
+                            lora_dropout=peft_config["lora_dropout"],
+                            init_lora_weights=peft_config["init_lora_weights"],
+                        )
+                        child.train(False)
+                        if peft_config["lora_dropout"] > 0:
+                            logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout")
+                        for p in child.parameters():
+                            p.requires_grad = False
+
+                    if peft_key.endswith(".lora_A.weight"):
+                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
+                        is_lora_a_loaded = True
+                    elif peft_key.endswith(".lora_A.bias"):
+                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
+                    elif peft_key.endswith(".lora_B.weight"):
+                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
+                        is_lora_b_loaded = True
+                    elif peft_key.endswith(".lora_B.bias"):
+                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
+
+                if is_lora_a_loaded and is_lora_b_loaded:
+                    logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")

+ 11 - 2
tests/test_full_model.py

@@ -1,3 +1,4 @@
+import peft
import pytest
import torch
import transformers
@@ -12,11 +13,16 @@ logger = get_logger(__name__)


@pytest.mark.forked
+@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
+def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+        MODEL_NAME,
+        initial_peers=INITIAL_PEERS,
+        low_cpu_mem_usage=True,
+        torch_dtype=torch.float32,
+        active_adapter=ADAPTER_NAME if use_peft else None,
    )
    config = model.config
    assert isinstance(model, DistributedBloomForCausalLM)
@@ -54,6 +60,9 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
            ref_model = transformers.BloomForCausalLM.from_pretrained(
                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
            )
+            if use_peft:
+                ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
+                ref_model.train(False)
            if config.vocab_size < ref_model.config.vocab_size:
                ref_model.resize_token_embeddings(config.vocab_size)
                logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

+ 66 - 0
tests/test_peft.py

@@ -0,0 +1,66 @@
+import os
+import shutil
+
+import pytest
+from huggingface_hub import snapshot_download
+
+from petals.utils.peft import check_peft_repository, load_peft
+
+UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
+SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
+TMP_CACHE_DIR = "tmp_cache/"
+
+
+def clear_dir(path_to_dir):
+    shutil.rmtree(path_to_dir)
+    os.mkdir(path_to_dir)
+
+
+def dir_empty(path_to_dir):
+    files = os.listdir(path_to_dir)
+    return len(files) == 0
+
+
+@pytest.mark.forked
+def test_check_peft():
+    assert not check_peft_repository(UNSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load."
+    assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
+
+
+@pytest.mark.forked
+def test_load_noncached(tmpdir):
+    clear_dir(tmpdir)
+    with pytest.raises(Exception):
+        load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)
+
+    assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"
+
+    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
+
+    assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"
+
+
+@pytest.mark.forked
+def test_load_cached(tmpdir):
+    clear_dir(tmpdir)
+    snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)
+
+    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
+
+
+@pytest.mark.forked
+def test_load_layer_exists(tmpdir):
+    clear_dir(tmpdir)
+
+    load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)
+
+
+@pytest.mark.forked
+def test_load_layer_nonexists(tmpdir):
+    clear_dir(tmpdir)
+
+    load_peft(
+        SAFE_PEFT_REPO,
+        block_idx=1337,
+        cache_dir=tmpdir,
+    )

+ 2 - 0
tests/test_utils.py

@@ -11,3 +11,5 @@ if not MODEL_NAME:
    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")

REF_NAME = os.environ.get("REF_NAME")
+
+ADAPTER_NAME = os.environ.get("ADAPTER_NAME")