
reimplement get_remote_module

Pavel Samygin · 3 years ago
commit 3cf325004d

+ 4 - 4
README.md

@@ -37,18 +37,18 @@ Then open a python notebook or console and run:
 ```python
 import torch
 import hivemind
-from src import get_remote_module
+from src import DistributedBloomConfig, get_remote_module
 
 
 dht = hivemind.DHT(
     initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS],  # e.g. /ip4/127.0.0.1/...
     client_mode=True, start=True,
 )
-
-layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'])
+config = DistributedBloomConfig.from_pretrained("bigscience/test-bloom-6b3")
+layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'], config)
 assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT"
 # test forward/backward, two blocks
-outputs, = layer4(*layer3(torch.randn(1, 64, 4096)))
+outputs = layer4(layer3(torch.randn(1, 64, 4096)))
 loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 

+ 1 - 1
src/__init__.py

@@ -1,6 +1,6 @@
 from src.bloom import *
 from src.client import *
-from src.dht_utils import declare_active_modules
+from src.dht_utils import declare_active_modules, get_remote_module
 
 project_name = "bloomd"
 __version__ = "0.2"

+ 7 - 3
src/client/remote_sequential.py

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import logging
 from typing import Optional, Union
 
 import torch
@@ -13,7 +12,6 @@ from src.client.inference_session import RemoteSequentialInferenceSession
 from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
-from src.dht_utils import _create_remote_modules_from_infos
 from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
@@ -59,7 +57,13 @@ class RemoteSequential(nn.Module):
     def __getitem__(self, ix: Union[int, slice]) -> Union["RemoteTransformerBlock", RemoteSequential]:
         assert isinstance(ix, (int, slice))
         if isinstance(ix, int):
-            return self.__getitem__((ix, ix + 1))
+            return RemoteTransformerBlock(
+                self.config,
+                self.dht,
+                dht_prefix=self.dht_prefix,
+                p2p=self.p2p,
+                sequence_manager=self.sequence_manager[ix],
+            )
         else:
             return RemoteSequential(
                 self.config,
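
With this change, indexing a `RemoteSequential` with an integer builds a `RemoteTransformerBlock` directly (wrapping a single-block sequence manager) instead of recursing through the slice branch. A minimal sketch of the resulting behavior, assuming `seq` is an existing `RemoteSequential` built from a `DistributedBloomConfig` and a running DHT (all names below are illustrative):

```python
import torch

# seq is an already-constructed RemoteSequential spanning the whole model
block = seq[3]        # int index -> a single RemoteTransformerBlock backed by
                      # sequence_manager=seq.sequence_manager[3]
subseq = seq[2:5]     # slice -> a new RemoteSequential over blocks 2..4,
                      # sharing the same config, dht and p2p handles

hidden_states = torch.randn(1, 8, seq.config.hidden_size)
outputs = subseq(hidden_states)  # runs the selected remote blocks in order
```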

+ 12 - 9
src/client/sequence_manager.py

@@ -82,6 +82,7 @@ class RemoteSequenceManager:
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
             if info is None:
                 logger.warning(f"Found no block info for block {uid}")
+                continue
             if not isinstance(info, RemoteModuleInfo):
                 logger.warning(f"Unexpected dht entry type for {uid}: {info}")
             if not info.servers:
@@ -95,22 +96,24 @@ class RemoteSequenceManager:
         closed_spans = []
         active_spans = {}
         for block_index, info in enumerate(block_infos):
-            for peer_id, server in info.servers.items():
-                if server.state != ServerState.ONLINE:
-                    continue
-                if peer_id not in active_spans:
-                    active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
-                else:  # peer_id in active_spans
-                    active_spans[peer_id].end = block_index + 1
+            if info is not None:
+                for peer_id, server in info.servers.items():
+                    if server.state != ServerState.ONLINE:
+                        continue
+                    if peer_id not in active_spans:
+                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
+                    else:  # peer_id in active_spans
+                        active_spans[peer_id].end = block_index + 1
 
             for peer_id in list(active_spans.keys()):
                 if (
-                    peer_id not in info.servers
+                    info is None
+                    or peer_id not in info.servers
                     or info.servers[peer_id].state != ServerState.ONLINE
                     or block_index == len(block_infos) - 1
                 ):
                     closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans
+        assert not active_spans, f"spans: {active_spans}"
 
         closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
 

+ 1 - 1
src/client/sequential_autograd.py

@@ -110,7 +110,7 @@ async def sequential_forward(
     If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
     """
 
-    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
+    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)

+ 36 - 1
src/dht_utils.py

@@ -8,7 +8,8 @@ from functools import partial
 from typing import Dict, List, Optional, Sequence, Union
 
 from hivemind.dht import DHT, DHTNode, DHTValue
-from hivemind.p2p import P2P, PeerID
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
 import src
@@ -71,6 +72,40 @@ async def _declare_active_modules(
     )
 
 
+def get_remote_module(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+    return_future: bool = False,
+) -> List[src.RemoteTransformerBlock]:
+    """
+    :param uid_or_uids: find one or more modules with these ids from across the DHT
+    :param config: model config, usually taken from .from_pretrained(MODEL_NAME)
+    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+    :returns: a list of [RemoteTransformerBlock]
+    """
+    return RemoteExpertWorker.run_coroutine(
+        _get_distinct_blocks(dht, uid_or_uids, config, dht_prefix), return_future=return_future
+    )
+
+
+async def _get_distinct_blocks(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+):
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    uids = [uid_or_uids] if single_uid else uid_or_uids
+    p2p = await dht.replicate_p2p()
+    managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
+    modules = [
+        src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
+    ]
+    return modules[0] if single_uid else modules
+
+
 def get_remote_module_infos(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
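
The reimplemented `get_remote_module` mirrors how the README and the updated tests call it: it requires a `DistributedBloomConfig` and builds one single-block `RemoteSequenceManager` per uid. Despite the `List[...]` annotation, passing a single uid returns a single `RemoteTransformerBlock`, while a list of uids returns a list. A short usage sketch, with `MODEL_NAME` and `INITIAL_PEERS` standing in for your own checkpoint name and peer addresses:

```python
import hivemind

from src import DistributedBloomConfig, get_remote_module
from src.data_structures import UID_DELIMITER

dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)

# single uid -> single RemoteTransformerBlock
block0 = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0", config)

# list of uids -> list of RemoteTransformerBlock
block3, block4 = get_remote_module(
    dht, [f"{MODEL_NAME}{UID_DELIMITER}3", f"{MODEL_NAME}{UID_DELIMITER}4"], config
)
```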

+ 6 - 5
tests/test_block_exact_match.py

@@ -9,18 +9,19 @@ from test_utils import *
 
 import src
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_sequential import RemoteSequential
+from src.client.remote_sequential import RemoteTransformerBlock
+from src.data_structures import UID_DELIMITER
+from src.dht_utils import get_remote_module
 
 
 @pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-
     config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+
     for block_index in random.sample(range(config.n_layer), 3):
-        remote_block = RemoteSequential(config, dht)[block_index]
-        assert isinstance(remote_block, RemoteSequential)
+        remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
+        assert isinstance(remote_block, RemoteTransformerBlock)
 
         inputs = torch.randn(1, 8, config.hidden_size)
         (outputs_forward,) = remote_block(inputs)

+ 9 - 7
tests/test_chained_calls.py

@@ -13,16 +13,18 @@ from test_utils import *
 
 import src
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_sequential import RemoteSequential
+from src.client.remote_sequential import RemoteTransformerBlock
+from src.data_structures import UID_DELIMITER
+from src.dht_utils import get_remote_module
 
 
 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = RemoteSequential(src.DistributedBloomConfig.from_pretrained(MODEL_NAME), dht)[0]
+    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0", config)
     assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteSequential)
+    assert isinstance(remote_block, RemoteTransformerBlock)
 
     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
     remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
@@ -52,10 +54,10 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = RemoteSequential(src.DistributedBloomConfig.from_pretrained(MODEL_NAME), dht)[0]
+    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0", config)
     assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteSequential)
+    assert isinstance(remote_block, RemoteTransformerBlock)
 
     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
     remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)