
make a test for remote sequential with grads

justheuristic 3 years ago
parent
commit
c597f4d520
3 changed files with 66 additions and 17 deletions
  1. + 1 - 0   .github/workflows/run-tests.yaml
  2. + 11 - 17 src/client/remote_sequential.py
  3. + 54 - 0  tests/test_remote_sequential.py

+ 1 - 0
.github/workflows/run-tests.yaml

@@ -86,6 +86,7 @@ jobs:
 
           REF_NAME=$MODEL_NAME pytest tests/test_chained_calls.py
           
+          pytest tests/test_remote_sequential.py
           REF_NAME=bigscience/bloom-350m pytest tests/test_full_model.py
           
           kill -s SIGINT $SERVER1_PID $SERVER2_PID
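Note: the new pytest step relies on the test servers launched earlier in this job (the ones stopped via SERVER1_PID / SERVER2_PID above), and tests/test_remote_sequential.py (added below) reads the MODEL_NAME and INITIAL_PEERS environment variables, which are presumably exported earlier in the job as for the existing test steps.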

+ 11 - 17
src/client/remote_sequential.py

@@ -28,40 +28,34 @@ class RemoteSequential(nn.Module):
         self,
         config: src.DistributedBloomConfig,
         dht: DHT,
-        prefix: str,
+        dht_prefix: Optional[str] = None,
         p2p: Optional[P2P] = None,
         sequence_manager: Optional[RemoteSequenceManager] = None,
     ):
         logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
-        if prefix.endswith(UID_DELIMITER):
-            logger.warning(
-                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
-                f"This will cause {self.__class__.__name__} to look for modules under "
-                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
-            )
-
         super().__init__()
         self.config = config
         self.dht = dht
-        self.prefix = prefix
+        self.dht_prefix = dht_prefix or config.dht_prefix
         self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
 
-        block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+        num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
+        block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
         if sequence_manager is None:
             logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
             self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
             self.is_subsequence = False
         else:
+            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
+            self.sequence_manager = sequence_manager
             assert isinstance(sequence_manager.block_uids, list)
-            logger.debug(f"Reusing sequence manager with {len(self.sequence_manager)}")
             self.is_subsequence = self.sequence_manager.block_uids == block_uids
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
-        for block_index in range(self.config.n_layer):
+        for block in iter(self):
             for retry_index in range(self.sequence_manager.max_retries):
                 try:
-                    block = self[block_index]
                     (outputs,) = block(inputs)
                     assert isinstance(outputs, torch.Tensor)
                     assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
@@ -77,20 +71,20 @@ class RemoteSequential(nn.Module):
     def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
         assert isinstance(ix, (int, slice))
         if isinstance(ix, int):
-            assert 0 <= ix < self.config.n_layer
+            assert 0 <= ix < len(self)
             (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
             return module
         else:
             return RemoteSequential(
                 self.config,
                 self.dht,
-                prefix=self.prefix,
+                dht_prefix=self.dht_prefix,
                 p2p=self.p2p,
                 sequence_manager=self.sequence_manager[ix],
             )
 
     def __iter__(self):
-        for block_index in range(self.config.n_layer):
+        for block_index in range(len(self)):
             yield self[block_index]
 
     def __len__(self):
@@ -101,4 +95,4 @@ class RemoteSequential(nn.Module):
         return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
 
     def extra_repr(self) -> str:
-        return f"{self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
+        return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

+ 54 - 0
tests/test_remote_sequential.py

@@ -0,0 +1,54 @@
+import os
+
+import torch
+import transformers
+from hivemind import DHT, get_logger, use_hivemind_log_handler
+
+from src import RemoteSequential
+from src.client.remote_model import DistributedBloomForCausalLM, DistributedBloomConfig
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()
+
+
+MODEL_NAME = os.environ.get("MODEL_NAME")
+if not MODEL_NAME:
+    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
+
+
+def test_remote_sequential():
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+    test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
+    grad_proj = torch.randn(1, 5, config.hidden_size)
+
+    sequential = RemoteSequential(config, dht)
+
+    full_outputs = sequential(test_inputs)
+    (full_outputs * grad_proj).sum().backward()
+    assert test_inputs.grad is not None
+    full_grad = test_inputs.grad.clone()
+    test_inputs.grad.data.zero_()
+
+    first_half = sequential[:config.n_layer // 2]
+    second_half = sequential[config.n_layer // 2:]
+    assert len(first_half) + len(second_half) == len(sequential)
+    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
+    for m in sequential, first_half, second_half:
+        assert isinstance(repr(m), str)
+
+    hidden = first_half(test_inputs)
+    assert isinstance(hidden, torch.Tensor)
+    assert hidden.shape == test_inputs.shape
+    assert hidden.requires_grad
+    second_half_outputs = second_half(hidden)
+    assert torch.allclose(second_half_outputs, full_outputs)
+
+    (second_half_outputs * grad_proj).sum().backward()
+    assert torch.allclose(test_inputs.grad, full_grad)
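
For reference, the gradient check above follows a generic split-and-compare pattern. A self-contained local analogue with plain PyTorch layers (purely illustrative, no remote calls or project code involved):

    import torch
    from torch import nn

    torch.manual_seed(0)
    blocks = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)])
    inputs = torch.randn(1, 5, 16, requires_grad=True)
    grad_proj = torch.randn(1, 5, 16)

    # run the full stack and keep the input gradient as a reference
    full_outputs = blocks(inputs)
    (full_outputs * grad_proj).sum().backward()
    full_grad = inputs.grad.clone()
    inputs.grad.data.zero_()

    # the same computation split into two halves must reproduce both outputs and input grads
    split_outputs = blocks[2:](blocks[:2](inputs))
    assert torch.allclose(split_outputs, full_outputs)
    (split_outputs * grad_proj).sum().backward()
    assert torch.allclose(inputs.grad, full_grad)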