
merge test_chained_forward_backward.py and test_chained_inference into one file

Dmitry Baranchuk 3 years ago
commit d82e2b9af5

+ 3 - 1
.github/workflows/run-tests.yaml

@@ -79,7 +79,9 @@ jobs:
           export PYTHONPATH=. REF_NAME=bloom-testing/test-bloomd-350m
           BLOCK_UID=bloom-testing/test-bloomd-350m.0 REF_INDEX=0 pytest tests/test_block_exact_match.py
           BLOCK_UID=bloom-testing/test-bloomd-350m.19 REF_INDEX=19 pytest tests/test_block_exact_match.py
-          unser REF_NAME
+          
+          MODEL_NAME=bloom-testing/test-bloomd-350m pytest tests/test_chained_calls.py
+          unset REF_NAME
           MODEL_NAME=bloom-testing/test-bloomd-350m REF_NAME=bigscience/bloom-350m pytest tests/test_full_model.py
           
           kill -s SIGINT $SERVER1_PID $SERVER2_PID

+ 97 - 0
tests/test_chained_calls.py

@@ -0,0 +1,97 @@
+######
+# Warning: this test is a work in progress. It will be modified soon.
+# - if you want more stable tests, see test_block_exact_match
+# - if you want to figure out chained inference, ask yozh
+
+import os
+
+import hivemind
+import torch
+import transformers
+from hivemind.moe.expert_uid import ExpertInfo, UID_DELIMITER
+
+from src.bloom.from_pretrained import load_pretrained_block
+from src.client.remote_block import RemoteTransformerBlock
+from src.dht_utils import get_remote_module
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()
+
+
+MODEL_NAME = os.environ.get("MODEL_NAME")
+if not MODEL_NAME:
+    raise RuntimeError("Must specify MODEL_NAME as a name of a model to be tested")
+
+REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
+
+
+def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
+    assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
+    assert isinstance(remote_block, RemoteTransformerBlock)
+
+    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
+    remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
+
+    ref_blocks = [
+        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
+        load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
+    ]
+    inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
+    outputs_rpc = remote_block.forward(inputs)[0]
+    outputs_rpc.sum().backward()
+    grads_rpc = inputs.grad
+
+    inputs.grad = None
+    hidden_states = inputs
+    for ref_block in ref_blocks:
+        hidden_states = ref_block.forward(hidden_states)[0]
+    outputs_ref = hidden_states
+    outputs_ref.sum().backward()
+    grads_ref = inputs.grad
+
+    assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
+    assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)
+
+
+def test_chained_inference_exact_match(atol_inference=1e-4):
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
+    assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
+    assert isinstance(remote_block, RemoteTransformerBlock)
+
+    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
+    remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)
+
+    inputs = torch.randn(1, 8, config.hidden_size)
+
+    outputs_inference = []
+    with remote_block.inference_session() as sess:
+        for i in range(inputs.shape[1]):
+            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+    outputs_inference = torch.cat(outputs_inference, dim=1)
+
+    ref_blocks = [
+        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
+    ]
+    outputs_ref = []
+    caches = [None, None]
+    for i in range(inputs.shape[1]):
+        new_caches = []
+        hidden_states = inputs[:, i : i + 1, :]
+        for ref_block, cache in zip(ref_blocks, caches):
+            with torch.no_grad():
+                hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
+                new_caches.append(new_cache)
+
+        outputs_ref.append(hidden_states)
+        caches = new_caches
+    outputs_ref = torch.cat(outputs_ref, dim=1)
+    assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)
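
Both merged tests rely on the same trick: a RemoteTransformerBlock resolved for a single UID is patched with an ExpertInfo whose UID field is a space-separated chain, after which every forward, backward, or inference call runs through all listed server-side blocks in order. Below is a minimal standalone sketch of that pattern; it uses only names that appear in this diff (get_remote_module, ExpertInfo, the INITIAL_PEERS and MODEL_NAME environment variables) and illustrates the test mechanism, not a stable public API.

# Minimal sketch (not part of the diff) of the chained-call pattern used by the tests above.
import os

import hivemind
import torch
import transformers
from hivemind.moe.expert_uid import ExpertInfo, UID_DELIMITER

from src.client.remote_block import RemoteTransformerBlock
from src.dht_utils import get_remote_module

dht = hivemind.DHT(initial_peers=os.environ["INITIAL_PEERS"].split(), client_mode=True, start=True)
model_name = os.environ["MODEL_NAME"]

block = get_remote_module(dht, f"{model_name}{UID_DELIMITER}0")
assert isinstance(block, RemoteTransformerBlock)

_ = block.info  # force lazy info init before patching, as in the tests above
# A space-separated UID string makes the server execute blocks 3, 4, and 5 in sequence per call.
block._info = ExpertInfo(f"{model_name}.3 {model_name}.4 {model_name}.5", block._info.peer_id)

config = transformers.AutoConfig.from_pretrained(model_name)
hidden_states = torch.randn(1, 1, config.hidden_size)
outputs = block.forward(hidden_states)[0]  # one RPC call, three chained blocks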

+ 0 - 59
tests/test_chained_forward_backward.py

@@ -1,59 +0,0 @@
-######
-# Warning:torch this test is a work in progress. It will be modified soon.
-# - if you want more stable tests, see test_block_exact_match
-# - if you want to figure out chained inference, ask yozh
-
-import os
-
-import hivemind
-import torch
-from hivemind.moe.expert_uid import ExpertInfo
-
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.dht_utils import get_remote_module
-
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-BLOCK_UID = os.environ.get("BLOCK_UID")
-if not BLOCK_UID:
-    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-
-
-# seq_length > 128: rpc_forward_stream & rpc_backward_stream
-# seq_length <= 128: rpc_forward & rpc_backward
-def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    (remote_block,) = get_remote_module(dht, BLOCK_UID)
-    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-
-    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
-    remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4 bloom6b3.5", remote_block._info.peer_id)
-
-    ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
-    ]
-    inputs = torch.randn(1, seq_length, 4096, requires_grad=True)
-    outputs_rpc = remote_block.forward(inputs)[0]
-    outputs_rpc.sum().backward()
-    grads_rpc = inputs.grad
-
-    inputs.grad = None
-    hidden_states = inputs
-    for ref_block in ref_blocks:
-        hidden_states = ref_block.forward(hidden_states)[0]
-    outputs_ref = hidden_states
-    outputs_ref.sum().backward()
-    grads_ref = inputs.grad
-
-    assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
-    assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)

+ 0 - 64
tests/test_chained_inference.py

@@ -1,64 +0,0 @@
-######
-# Warning:torch this test is a work in progress. It will be modified soon.
-# - if you want more stable tests, see test_block_exact_match
-# - if you want to figure out chained inference, ask yozh
-
-import os
-
-import hivemind
-import torch
-from hivemind.moe.expert_uid import ExpertInfo
-
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.dht_utils import get_remote_module
-
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-BLOCK_UID = os.environ.get("BLOCK_UID")
-if not BLOCK_UID:
-    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID[-1].split(".")[-1]))
-
-
-def test_remote_block_exact_match(atol_inference=1e-4):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    remote_block = get_remote_module(dht, BLOCK_UID)
-    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-
-    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
-    remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4", remote_block._info.peer_id)
-
-    inputs = torch.randn(1, 8, 4096)
-
-    outputs_inference = []
-    with remote_block.inference_session() as sess:
-        for i in range(inputs.shape[1]):
-            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
-    outputs_inference = torch.cat(outputs_inference, dim=1)
-
-    ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
-    ]
-    outputs_ref = []
-    caches = [None, None]
-    for i in range(inputs.shape[1]):
-        new_caches = []
-        hidden_states = inputs[:, i : i + 1, :]
-        for ref_block, cache in zip(ref_blocks, caches):
-            with torch.no_grad():
-                hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
-                new_caches.append(new_cache)
-
-        outputs_ref.append(hidden_states)
-        caches = new_caches
-    outputs_ref = torch.cat(outputs_ref, dim=1)
-    assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)
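
For completeness, the reference side of the inference test (both in the deleted file above and in the merged tests/test_chained_calls.py) threads a per-block attention cache through a token-by-token loop. A condensed sketch of just that loop, assuming only the load_pretrained_block interface shown in this diff (use_cache=True returns an updated cache, layer_past feeds the previous one back in); the hidden size 4096 matches the bigscience/test-bloomd-6b3 reference used by default.

# Condensed sketch (not part of the diff) of the stepwise reference loop from the inference test.
import torch

from src.bloom.from_pretrained import load_pretrained_block

REF_NAME = "bigscience/test-bloomd-6b3"  # default REF_NAME used by the tests
ref_blocks = [load_pretrained_block(REF_NAME, i, torch_dtype=torch.float32) for i in (3, 4)]

inputs = torch.randn(1, 8, 4096)
caches = [None] * len(ref_blocks)
outputs_ref = []
for i in range(inputs.shape[1]):  # feed one token position at a time
    hidden_states = inputs[:, i : i + 1, :]
    new_caches = []
    for ref_block, cache in zip(ref_blocks, caches):
        with torch.no_grad():
            # each block consumes the previous step's cache and returns an updated one
            hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
            new_caches.append(new_cache)
    outputs_ref.append(hidden_states)
    caches = new_caches
# stepwise outputs to compare against the chained inference session
outputs_ref = torch.cat(outputs_ref, dim=1)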