
fix chained calls

Pavel Samygin, 3 years ago
parent
commit
a0e3c8943b
2 changed files with 35 additions and 14 deletions
  1. src/dht_utils.py (+27 -1)
  2. tests/test_chained_calls.py (+8 -13)

src/dht_utils.py (+27 -1)

@@ -72,13 +72,39 @@ async def _declare_active_modules(
     )
 
 
+def get_remote_sequence(
+    dht: DHT,
+    start: int,
+    stop: int,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+    return_future: bool = False,
+) -> Union[src.RemoteSequential, MPFuture]:
+    return RemoteExpertWorker.run_coroutine(
+        _get_sequence_blocks(dht, start, stop, config, dht_prefix), return_future=return_future
+    )
+
+
+async def _get_sequence_blocks(
+    dht: DHT,
+    start: int,
+    stop: int,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+) -> src.RemoteSequential:
+    uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
+    p2p = await dht.replicate_p2p()
+    manager = src.RemoteSequenceManager(dht, uids, p2p)
+    return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
+
+
 def get_remote_module(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
     config: src.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
+) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
     """
     :param uid_or_uids: find one or more modules with these ids from across the DHT
     :param config: model config, usually taken by .from_pretrained(MODEL_NAME)

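Like get_remote_module, the new get_remote_sequence wraps its coroutine in RemoteExpertWorker.run_coroutine, so a caller can either block for the result or ask for an MPFuture; get_remote_module's updated signature reflects the same MPFuture case. A minimal usage sketch, not part of the commit, assuming the INITIAL_PEERS and MODEL_NAME constants from test_utils and a swarm that actually serves blocks 3-6:

    import hivemind
    import src
    from src.dht_utils import get_remote_sequence
    from test_utils import INITIAL_PEERS, MODEL_NAME  # assumed test constants

    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)

    # Blocking variant: returns a RemoteSequential over blocks [3, 6)
    blocks = get_remote_sequence(dht, 3, 6, config)

    # Non-blocking variant: returns an MPFuture resolving to the same object
    future = get_remote_sequence(dht, 3, 6, config, return_future=True)
    blocks = future.result()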
tests/test_chained_calls.py (+8 -13)

@@ -7,24 +7,20 @@
 import hivemind
 import pytest
 import torch
-import transformers
-from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
 from test_utils import *
 
 import src
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_sequential import RemoteTransformerBlock
-from src.data_structures import UID_DELIMITER
-from src.dht_utils import get_remote_module
+from src.client.remote_sequential import RemoteSequential
+from src.dht_utils import get_remote_sequence
 
 
 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0", config)
-    assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
+    remote_blocks = get_remote_sequence(dht, 3, 6, config)
+    assert isinstance(remote_blocks, RemoteSequential)
 
     ref_blocks = [
         load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
@@ -32,7 +28,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
         load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),
     ]
     inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
-    outputs_rpc = remote_block.forward(inputs)[0]
+    outputs_rpc = remote_blocks.forward(inputs)
     outputs_rpc.sum().backward()
     grads_rpc = inputs.grad
 
@@ -52,14 +48,13 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0", config)
-    assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
+    remote_blocks = get_remote_sequence(dht, 3, 5, config)
+    assert isinstance(remote_blocks, RemoteSequential)
 
     inputs = torch.randn(1, 8, config.hidden_size)
 
     outputs_inference = []
-    with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+    with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)
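The hunk ends before the test's reference check. For context, a sketch of how the stepwise outputs could be compared against locally loaded blocks, following the pattern of the forward/backward test above; the exact assertion in the repository may differ:

    # Reference path (sketch): run the same inputs through local blocks 3..4
    ref_blocks = [load_pretrained_block(MODEL_NAME, i, torch_dtype=torch.float32) for i in range(3, 5)]
    outputs_ref = inputs
    for block in ref_blocks:
        outputs_ref = block(outputs_ref)[0]  # each block returns a tuple; hidden states come first
    assert torch.allclose(outputs_inference, outputs_ref, rtol=0, atol=atol_inference)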