Răsfoiți Sursa

Remove peft dependency for AMD GPUs

Aleksandr Borzunov 2 ani în urmă
părinte
comite
6b38bc89ef

+ 0 - 1
setup.cfg

@@ -46,7 +46,6 @@ install_requires =
     cpufeature>=0.2.0
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft>=0.4.0
     safetensors>=0.3.1
     Dijkstar>=2.6.0
 

+ 3 - 11
src/petals/server/backend.py

@@ -35,10 +35,6 @@ class TransformerBackend(ModuleBackend):
         max_chunk_size_bytes: int,
         **kwargs,
     ):
-        import petals.utils.peft as _peft_module
-
-        self._peft_module = _peft_module
-
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
@@ -98,13 +94,11 @@ class TransformerBackend(ModuleBackend):
 
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with self._peft_module.using_adapter(active_adapter):
-            return super().forward(*inputs)
+        return super().forward(*inputs)
 
     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with self._peft_module.using_adapter(active_adapter):
-            return super().backward(*inputs)
+        return super().backward(*inputs)
 
     @torch.inference_mode()
     def inference_step(
@@ -116,9 +110,7 @@ class TransformerBackend(ModuleBackend):
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
         seq_len = hidden_states.shape[1]
 
-        with self.memory_cache.use_cache(
-            *inference_info.cache_handles
-        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
+        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
 
             # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`

+ 1 - 11
src/petals/server/server.py

@@ -278,17 +278,7 @@ class Server:
         block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
         total_memory_per_block = block_size + self._cache_bytes_per_block
         if self.adapters:
-            # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
-            from petals.utils.peft import estimate_adapter_memory_per_block
-
-            total_memory_per_block += estimate_adapter_memory_per_block(
-                self.block_config,
-                self.torch_dtype,
-                self.adapters,
-                token=self.token,
-                cache_dir=self.cache_dir,
-                max_disk_space=self.max_disk_space,
-            )
+            raise RuntimeError("LoRA adapters are not supported on AMD GPUs")
 
         num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"

+ 1 - 10
src/petals/utils/convert_block.py

@@ -59,16 +59,7 @@ def convert_block(
         shard.to(device)
 
     if adapters:
-        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
-
-        create_lora_adapter(block, quant_type=quant_type)
-        for adapter_name in adapters:
-            adapter_config, adapter_state_dict = load_peft(
-                adapter_name,
-                block_idx=block_index,
-                **kwargs,
-            )
-            add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
+        raise RuntimeError("LoRA adapters are not supported on AMD GPUs")
 
     return block
 

+ 2 - 0
tests/test_peft.py

@@ -4,6 +4,8 @@ import shutil
 import pytest
 from huggingface_hub import snapshot_download
 
+pytest.skip("LoRA adapters are not supported on AMD GPUs", allow_module_level=True)
+
 from petals.utils.peft import check_peft_repository, load_peft
 
 UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"