
Add first version

Artem Chumachenko 1 year ago
parent
commit
01c3cf8d15

+ 6 - 0
src/petals/server/backend.py

@@ -29,10 +29,13 @@ class TransformerBackend(ModuleBackend):
     def __init__(
         self,
         *args,
+        block_index: int,
         config: PretrainedConfig,
         memory_cache: MemoryCache,
         backend_dtype: torch.dtype,
         max_chunk_size_bytes: int,
+        cache_dir: str,
+        max_disk_space: int,
         **kwargs,
     ):
         import petals.utils.peft as _peft_module
@@ -41,9 +44,12 @@ class TransformerBackend(ModuleBackend):
 
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
+        self.block_index = block_index
         self.config = config
         self.memory_cache = memory_cache
         self.max_chunk_size_bytes = max_chunk_size_bytes
+        self.cache_dir = cache_dir
+        self.max_disk_space = max_disk_space
 
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"

+ 31 - 2
src/petals/server/handler.py

@@ -152,6 +152,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 session_id = metadata.get("session_id")
                 alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
                 args_structure = metadata.get("args_structure")
+                active_adapter = self._get_active_adapter(metadata)
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
                 assert isinstance(
@@ -169,12 +170,12 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                 async with self._allocate_cache(
                     requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
-                ) as cache_handles:
+                ) as cache_handles, self._load_peft_module(requested_backends, active_adapter):
                     background_tasks = set()
                     async for output_tensors, can_push in iterate_rpc_inference(
                         requested_uids=requested_uids,
                         requested_backends=requested_backends,
-                        active_adapter=self._get_active_adapter(metadata),
+                        active_adapter=active_adapter,
                         input_iterator=self._iterate_inference_steps(
                             request, requests, session_id, requested_uids, context
                         ),
@@ -546,6 +547,34 @@ class TransformerConnectionHandler(ConnectionHandler):
         async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
             yield nested_pack(handles, descriptors)
 
+    @contextlib.asynccontextmanager
+    async def _load_peft_module(self, backends: Sequence[TransformerBackend], active_adapter: str):
+        if active_adapter == "":
+            yield
+        elif active_adapter in self.adapters:
+            yield
+        else:
+            try:
+                _peft_module = backends[0]._peft_module
+                token = None  # TODO: Provide token from user request maybe?
+
+                for backend in backends:
+                    adapter_config, adapter_state_dict = _peft_module.load_peft(
+                        active_adapter,
+                        block_idx=backend.block_index,
+                        token=token,
+                        cache_dir=backend.cache_dir,
+                        max_disk_space=backend.max_disk_space,
+                    )
+
+                    _peft_module.add_adapter_to_block(
+                        backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
+                    )
+                yield
+            finally:
+                for backend in backends:
+                    _peft_module.remove_adapter_from_block(backend.module, active_adapter)
+
     def _log_request(
         self,
         method: str,
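
Outside the diff, the behavior of the new _load_peft_module context manager boils down to a small, self-contained pattern, sketched below with stand-in attach/detach callables instead of the real load_peft, add_adapter_to_block and remove_adapter_from_block (with_temporary_adapter and DummyBackend are invented names for illustration): the adapter is attached to every requested block before the inference loop starts and detached once the session ends, even if it fails.

import asyncio
import contextlib

@contextlib.asynccontextmanager
async def with_temporary_adapter(backends, adapter_name, preloaded=()):
    # Nothing to attach if no adapter was requested or it is already loaded on the server.
    if adapter_name == "" or adapter_name in preloaded:
        yield
        return
    try:
        for backend in backends:
            backend.attach(adapter_name)   # stand-in for load_peft + add_adapter_to_block
        yield                              # the inference session runs here
    finally:
        for backend in backends:
            backend.detach(adapter_name)   # stand-in for remove_adapter_from_block

class DummyBackend:
    def attach(self, name):
        print(f"attach {name}")

    def detach(self, name):
        print(f"detach {name}")

async def main():
    async with with_temporary_adapter([DummyBackend()], "demo-adapter"):
        print("inference step")

asyncio.run(main())   # prints: attach demo-adapter, inference step, detach demo-adapter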

+ 3 - 0
src/petals/server/server.py

@@ -512,6 +512,9 @@ class ModuleContainer(threading.Thread):
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
+                    block_index=block_index,
+                    cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
                     config=block_config,
                     memory_cache=memory_cache,
                     backend_dtype=torch_dtype,

+ 4 - 3
src/petals/utils/convert_block.py

@@ -58,10 +58,11 @@ def convert_block(
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
 
-    if adapters:
-        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+    from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
+    create_lora_adapter(block, quant_type=quant_type)
 
-        create_lora_adapter(block, quant_type=quant_type)
+    if adapters:
         for adapter_name in adapters:
             adapter_config, adapter_state_dict = load_peft(
                 adapter_name,
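
With create_lora_adapter moved out of the if adapters: branch, every block gets its (initially empty) LoRA wrappers at startup, even when the server is launched without any adapters, which is what lets the handler attach an adapter on demand later. A minimal sketch of that idea using plain torch stand-ins (LoraCapableLinear, attach and the rank value are invented for illustration, not petals or peft API):

import torch
from torch import nn

class LoraCapableLinear(nn.Module):
    """Stand-in for a LoRA-wrapped linear layer: wrapped once at startup,
    adapters can be attached (or removed) at any later point."""

    def __init__(self, base: nn.Linear, rank: int = 8):
        super().__init__()
        self.base = base
        self.rank = rank
        self.lora_A = nn.ModuleDict()   # per-adapter down-projections
        self.lora_B = nn.ModuleDict()   # per-adapter up-projections

    def attach(self, name: str):
        self.lora_A[name] = nn.Linear(self.base.in_features, self.rank, bias=False)
        self.lora_B[name] = nn.Linear(self.rank, self.base.out_features, bias=False)
        nn.init.zeros_(self.lora_B[name].weight)   # a fresh adapter starts as a no-op

    def forward(self, x, active=None):
        out = self.base(x)
        if active in self.lora_A:
            out = out + self.lora_B[active](self.lora_A[active](x))
        return out

layer = LoraCapableLinear(nn.Linear(16, 16))   # wrapping happens even with no adapters
x = torch.randn(2, 16)
before = layer(x)
layer.attach("demo-adapter")                   # loaded on demand, e.g. during an inference request
after = layer(x, active="demo-adapter")
assert torch.allclose(before, after)           # zero-initialized lora_B keeps outputs unchanged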

+ 16 - 0
src/petals/utils/peft.py

@@ -267,6 +267,22 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
     logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
 
 
+def remove_adapter_from_block(block, adapter_name):
+    for _, module in block.named_modules():
+        for child_name, child in module.named_children():
+            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
+                continue
+
+            if adapter_name in child.lora_A:
+                del child.lora_A[adapter_name]
+            if adapter_name in child.lora_B:
+                del child.lora_B[adapter_name]
+
+            # TODO: check if this is needed
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+
 def estimate_adapter_memory_per_block(
     block_config: transformers.PretrainedConfig,
     torch_dtype: Optional[torch.dtype],
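
As a side note on remove_adapter_from_block, here is a sketch with a plain torch.nn.ModuleDict rather than the real peft lora layers (the "demo" adapter name and the toy sizes are made up): deleting the adapter's keys from lora_A/lora_B drops the block's only reference to those weights, so they can actually be freed, which is why the handler pairs this call with the attach step in its finally clause.

from torch import nn

child = nn.Module()                  # stand-in for a LoRA-wrapped linear layer
child.lora_A = nn.ModuleDict({"demo": nn.Linear(1024, 8, bias=False)})
child.lora_B = nn.ModuleDict({"demo": nn.Linear(8, 1024, bias=False)})

print(sum(p.numel() for p in child.parameters()))   # 16384: the adapter's weights
for container in (child.lora_A, child.lora_B):
    if "demo" in container:          # same membership check as remove_adapter_from_block
        del container["demo"]
print(sum(p.numel() for p in child.parameters()))   # 0: nothing references the adapter any more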