瀏覽代碼

Add memory cache usage

Artem Chumachenko 1 年之前
父節點
當前提交
1b21dd3217
共有 5 個文件被更改,包括 64 次插入33 次删除
  1. 4 4
      src/petals/server/backend.py
  2. 44 23
      src/petals/server/handler.py
  3. 7 1
      src/petals/server/server.py
  4. 0 5
      src/petals/utils/dht.py
  5. 9 0
      src/petals/utils/peft.py

+ 4 - 4
src/petals/server/backend.py

@@ -57,15 +57,15 @@ class TransformerBackend(ModuleBackend):
             assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
 
         max_batch_size = self.forward_pool.max_batch_size
-        device = self.module.devices[self.module.output_device_index]
+        self.device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
+            self.inference_step, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_inference"
         )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
+            self.forward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_forward"
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
+            self.backward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_backward"
         )
 
         self.dtype = backend_dtype

+ 44 - 23
src/petals/server/handler.py

@@ -15,6 +15,7 @@ from hivemind import (
     MSGPackSerializer,
     P2PContext,
     PeerID,
+    TensorDescriptor,
     deserialize_tensor_stream,
     deserialize_torch_tensor,
     nested_flatten,
@@ -170,7 +171,9 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                 async with self._allocate_cache(
                     requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
-                ) as cache_handles, self._load_peft_module(requested_backends, active_adapter):
+                ) as cache_handles, self._load_peft_module(
+                    requested_backends, active_adapter=active_adapter, timeout=alloc_timeout
+                ):
                     background_tasks = set()
                     async for output_tensors, can_push in iterate_rpc_inference(
                         requested_uids=requested_uids,
@@ -490,9 +493,9 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     def _get_active_adapter(self, metadata: dict) -> str:
         active_adapter = metadata.get("active_adapter", "")
-        if active_adapter and (active_adapter not in self.adapters):
-            raise KeyError(f"adapter {active_adapter} not found")
-        return active_adapter
+        if active_adapter:
+            return active_adapter
+        return ""
 
     def _serialize_grads(
         self,
@@ -548,31 +551,49 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield nested_pack(handles, descriptors)
 
     @contextlib.asynccontextmanager
-    async def _load_peft_module(self, backends: Sequence[TransformerBackend], active_adapter: str):
+    async def _load_peft_module(
+        self,
+        backends: Sequence[TransformerBackend],
+        *,
+        active_adapter: str,
+        timeout: float,
+    ):
         if active_adapter == "":
             yield
         elif active_adapter in self.adapters:
             yield
         else:
-            try:
-                _peft_module = backends[0]._peft_module
-                token = None  # TODO: Provide token from user request maybe?
-
-                for backend in backends:
-                    adapter_config, adapter_state_dict = _peft_module.load_peft(
-                        active_adapter,
-                        block_idx=backend.block_index,
-                        token=token,
-                        cache_dir=backend.cache_dir,
-                        max_disk_space=backend.max_disk_space,
-                    )
+            _peft_module = backends[0]._peft_module
+            token = None  # TODO: Provide token from user request maybe?
 
-                    _peft_module.add_adapter_to_block(
-                        backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
-                    )
-            finally:
-                for backend in backends:
-                    _peft_module.remove_adapter_from_block(backend.module, active_adapter)
+            estimated_peft_size = _peft_module.get_estimated_peft_module_size(
+                active_adapter,
+                token=token,
+            )
+
+            fake_descriptor = TensorDescriptor(
+                size=(estimated_peft_size,),
+                dtype=torch.int8,
+                device=backends[0].device,
+            )
+
+            async with backends[0].memory_cache.allocate_cache(fake_descriptor, timeout=timeout) as _:
+                try:
+                    for backend in backends:
+                        adapter_config, adapter_state_dict = _peft_module.load_peft(
+                            active_adapter,
+                            block_idx=backend.block_index,
+                            token=token,
+                            cache_dir=backend.cache_dir,
+                            max_disk_space=backend.max_disk_space,
+                        )
+
+                        _peft_module.add_adapter_to_block(
+                            backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
+                        )
+                finally:
+                    for backend in backends:
+                        _peft_module.remove_adapter_from_block(backend.module, active_adapter)
 
     def _log_request(
         self,

+ 7 - 1
src/petals/server/server.py

@@ -231,6 +231,8 @@ class Server:
         gib = 1024**3
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
+        self.adapters_cache_bytes = self.attn_cache_bytes
+        logger.info(f"Adapter cache for all blocks will consume up to {self.adapters_cache_bytes / gib:.2f} GiB")
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
         if throughput in ["auto", "eval", "dry_run"]:
@@ -335,6 +337,7 @@ class Server:
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
+                adapters_cache_bytes=self.adapters_cache_bytes,
                 server_info=self.server_info,
                 model_info=self.model_info,
                 block_indices=block_indices,
@@ -442,6 +445,7 @@ class ModuleContainer(threading.Thread):
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
+        adapters_cache_bytes: int,
         server_info: ServerInfo,
         model_info: ModelInfo,
         block_indices: List[int],
@@ -464,7 +468,7 @@ class ModuleContainer(threading.Thread):
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes + adapters_cache_bytes, max_alloc_timeout)
 
         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(
@@ -517,6 +521,8 @@ class ModuleContainer(threading.Thread):
                     memory_cache=memory_cache,
                     backend_dtype=torch_dtype,
                     max_chunk_size_bytes=max_chunk_size_bytes,
+                    cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
                     args_schema=(
                         BatchTensorDescriptor(
                             1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression

+ 0 - 5
src/petals/utils/dht.py

@@ -111,11 +111,6 @@ async def _get_remote_module_infos(
             try:
                 peer_id = PeerID.from_base58(peer_id)
                 server_info = ServerInfo.from_tuple(server_info.value)
-
-                if active_adapter and active_adapter not in server_info.adapters:
-                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
-                    continue
-
                 servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
                 logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")

+ 9 - 0
src/petals/utils/peft.py

@@ -128,6 +128,15 @@ def load_peft(
             time.sleep(delay)
 
 
+def get_estimated_peft_module_size(
+        repo_id: str,
+        revision: Optional[str] = None,
+        token: Optional[Union[str, bool]] = None,
+):
+    weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+    return get_hf_file_metadata(weight_url, token=token).size
+
+
 class AdapterContextMixin:
     """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""