5
0
Эх сурвалжийг харах

add minimalistic peft test

Your Name 2 жил өмнө
parent
commit
6db11c6483

+ 2 - 1
src/petals/client/remote_sequential.py

@@ -28,6 +28,7 @@ class RemoteSequential(nn.Module):
         dht: Optional[DHT] = None,
         start_block: Optional[int] = None,
         end_block: Optional[int] = None,
+        **kwargs
     ):
         super().__init__()
         self.config = config
@@ -41,7 +42,7 @@ class RemoteSequential(nn.Module):
             if end_block is None:
                 end_block = self.config.num_hidden_layers
             block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
-            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
+            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
         self.sequence_manager = sequence_manager
 
     def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):

+ 6 - 2
src/petals/client/routing/sequence_manager.py

@@ -78,6 +78,7 @@ class RemoteSequenceManager:
         *,
         dht: Optional[DHT] = None,
         state: Optional[SequenceManagerState] = None,
+        extra_metadata: Optional[Dict[str, Any]] = None
     ):
         assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
         assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@@ -98,6 +99,7 @@ class RemoteSequenceManager:
             )
         assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance"
         self.dht = dht
+        self.extra_metadata = extra_metadata if extra_metadata is not None else {}
 
         if state.p2p is None:
             state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
@@ -167,7 +169,9 @@ class RemoteSequenceManager:
         assert isinstance(ix, (int, slice))
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
-        return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix])
+        return type(self)(
+            self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix], extra_metadata=self.extra_metadata
+        )
 
     def update(self, *, wait: bool):
         """Run an asynchronous update in background as soon as possible"""
@@ -307,7 +311,7 @@ class RemoteSequenceManager:
         :param kwargs: additional request context, such as remote peer ID
         :returns: msgpack-serialized metadata dict that will be passed alongside a given request
         """
-        return dict(points=self.policy.get_points(protocol, *args, **kwargs))
+        return dict(**self.extra_metadata, points=self.policy.get_points(protocol, *args, **kwargs))
 
     def shutdown(self):
         self._thread.shutdown()

+ 25 - 4
src/petals/dht_utils.py

@@ -22,6 +22,7 @@ def declare_active_modules(
     expiration_time: DHTExpiration,
     state: ServerState,
     throughput: float,
+    adapters: Optional[Sequence[str]] = None,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
@@ -39,6 +40,7 @@ def declare_active_modules(
         uids = list(uids)
     for uid in uids:
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
+
     return dht.run_coroutine(
         partial(
             _declare_active_modules,
@@ -46,6 +48,7 @@ def declare_active_modules(
             expiration_time=expiration_time,
             state=state,
             throughput=throughput,
+            adapters=list(adapters or []),
         ),
         return_future=not wait,
     )
@@ -58,12 +61,13 @@ async def _declare_active_modules(
     expiration_time: DHTExpiration,
     state: ServerState,
     throughput: float,
+    adapters: List[str],
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[(state.value, throughput)] * len(uids),
+        values=[(state.value, throughput, adapters)] * len(uids),
         expiration_time=expiration_time,
         num_workers=num_workers,
     )
@@ -73,18 +77,30 @@ def get_remote_module_infos(
     dht: DHT,
     uids: Sequence[ModuleUID],
     expiration_time: Optional[DHTExpiration] = None,
+    active_adapter: Optional[str] = None,
     *,
     latest: bool = False,
     return_future: bool = False,
 ) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
     return dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest),
+        partial(
+            _get_remote_module_infos,
+            uids=uids,
+            active_adapter=active_adapter,
+            expiration_time=expiration_time,
+            latest=latest,
+        ),
         return_future=return_future,
     )
 
 
 async def _get_remote_module_infos(
-    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool
+    dht: DHT,
+    node: DHTNode,
+    uids: List[ModuleUID],
+    active_adapter: Optional[str],
+    expiration_time: Optional[DHTExpiration],
+    latest: bool,
 ) -> List[Optional[RemoteModuleInfo]]:
     if latest:
         assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
@@ -105,7 +121,12 @@ async def _get_remote_module_infos(
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
-                state, throughput = server_info.value
+                state, throughput = server_info.value[:2]
+                available_adapters = server_info.value[2] if len(server_info.value) > 2 else []
+                if active_adapter is not None and active_adapter not in available_adapters:
+                    logger.warning(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
+                    continue
+
                 if not (
                     isinstance(state, int)
                     and isinstance(throughput, float)

+ 10 - 7
src/petals/server/backend.py

@@ -83,13 +83,14 @@ class TransformerBackend(ModuleBackend):
 
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        if active_adapter and not self.load_adapter_(active_adapter):
+        print("--forward...")
+        if not self.load_adapter_(active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         return super().forward(*inputs)
 
     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        if active_adapter and not self.load_adapter_(active_adapter):
+        if not self.load_adapter_(active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         return super().backward(*inputs)
 
@@ -102,7 +103,8 @@ class TransformerBackend(ModuleBackend):
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
-        if inference_info.active_adapter and not self.load_adapter_(inference_info.active_adapter):
+        print("--inference...")
+        if not self.load_adapter_(inference_info.active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
@@ -156,14 +158,15 @@ class TransformerBackend(ModuleBackend):
             p.data = dummy
 
     def load_adapter_(self, active_adapter: str = "") -> bool:
-        """Try to make a given adapter set active if it was loaded. Return True if loaded, False if no such adapter"""
-        adapter_is_loaded = False
+        """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
+        print("LOADING ADAPTER [", active_adapter, "]")
+        adapter_was_loaded = False
         for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
             if isinstance(layer, peft.tuners.lora.Linear):
                 layer.active_adapter = active_adapter  # empty string for no adapter
                 if active_adapter in layer.lora_A.keys():
-                    adapter_is_loaded = True
-        return adapter_is_loaded
+                    adapter_was_loaded = True
+        return adapter_was_loaded or active_adapter == ""
 
 
 def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):

+ 2 - 0
src/petals/server/handler.py

@@ -356,6 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             active_adapter = metadata.get("active_adapter", "")
+            print("ACTIVE_ADAPTER: [", active_adapter, "]")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -383,6 +384,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             active_adapter = metadata.get("active_adapter", "")
+            print("ACTIVE_ADAPTER: [", active_adapter, "]")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)

+ 8 - 0
src/petals/server/server.py

@@ -396,6 +396,7 @@ class ModuleContainer(threading.Thread):
             module_uids,
             dht,
             ServerState.JOINING,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -469,6 +470,7 @@ class ModuleContainer(threading.Thread):
                 expiration_time=get_dht_time() + expiration,
                 state=ServerState.OFFLINE,
                 throughput=throughput,
+                adapters=adapters,
             )
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
@@ -482,6 +484,7 @@ class ModuleContainer(threading.Thread):
             dht,
             dht_prefix,
             blocks,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -497,6 +500,7 @@ class ModuleContainer(threading.Thread):
         inference_max_length: int,
         num_handlers: int,
         throughput: float,
+        adapters: Optional[Sequence[str]],
         update_period: float,
         expiration: Optional[float] = None,
         request_timeout: float,
@@ -534,6 +538,7 @@ class ModuleContainer(threading.Thread):
             list(self.module_backends.keys()),
             dht,
             ServerState.ONLINE,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -633,6 +638,7 @@ class ModuleAnnouncerThread(threading.Thread):
         module_uids: List[str],
         dht: DHT,
         state: ServerState,
+        adapters: Optional[Sequence[str]],
         *,
         throughput: float,
         update_period: float = 30,
@@ -643,6 +649,7 @@ class ModuleAnnouncerThread(threading.Thread):
         self.module_uids = module_uids
         self.dht = dht
         self.state = state
+        self.adapters = adapters
         self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
@@ -656,6 +663,7 @@ class ModuleAnnouncerThread(threading.Thread):
                 expiration_time=get_dht_time() + self.expiration,
                 state=self.state,
                 throughput=self.throughput,
+                adapters=self.adapters,
             )
             if self.stop.wait(self.update_period):
                 break

+ 2 - 0
src/petals/utils/peft.py

@@ -154,6 +154,8 @@ def create_lora_adapter(block):
                 )
             if lora_wrapped_child:
                 lora_wrapped_child.active_adapter = None
+                lora_wrapped_child.weight = child.weight
+                lora_wrapped_child.bias = child.bias
                 for p in lora_wrapped_child.parameters():
                     p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)

+ 8 - 1
tests/test_full_model.py

@@ -1,6 +1,7 @@
 import pytest
 import torch
 import transformers
+import peft
 from hivemind import get_logger
 from transformers.generation import BeamSearchScorer
 from transformers.models.bloom import BloomForCausalLM
@@ -12,12 +13,16 @@ logger = get_logger(__name__)
 
 
 @pytest.mark.forked
+@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
+def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
+    if use_peft:
+        model.transformer.h.sequence_manager.extra_metadata = dict(active_adapter=ADAPTER_NAME)
+
     config = model.config
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.num_hidden_layers
@@ -54,6 +59,8 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
             ref_model = transformers.BloomForCausalLM.from_pretrained(
                 REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
             )
+            if use_peft:
+                ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
             if config.vocab_size < ref_model.config.vocab_size:
                 ref_model.resize_token_embeddings(config.vocab_size)
                 logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

+ 2 - 0
tests/test_utils.py

@@ -11,3 +11,5 @@ if not MODEL_NAME:
     raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
 
 REF_NAME = os.environ.get("REF_NAME")
+
+ADAPTER_NAME = os.environ.get("PEFT_NAME")