
reimplement get_remote_module

Pavel Samygin, 3 years ago
commit 3cf325004d

README.md (+4 -4)

@@ -37,18 +37,18 @@ Then open a python notebook or console and run:
 ```python
 import torch
 import hivemind
-from src import get_remote_module
+from src import DistributedBloomConfig, get_remote_module


 dht = hivemind.DHT(
     initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS],  # e.g. /ip4/127.0.0.1/...
     client_mode=True, start=True,
 )
-
-layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'])
+config = DistributedBloomConfig.from_pretrained("bigscience/test-bloom-6b3")
+layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'], config)
 assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT"
 # test forward/backward, two blocks
-outputs, = layer4(*layer3(torch.randn(1, 64, 4096)))
+outputs = layer4(layer3(torch.randn(1, 64, 4096)))
 loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()

src/__init__.py (+1 -1)

@@ -1,6 +1,6 @@
 from src.bloom import *
 from src.client import *
-from src.dht_utils import declare_active_modules
+from src.dht_utils import declare_active_modules, get_remote_module

 project_name = "bloomd"
 __version__ = "0.2"

src/client/remote_sequential.py (+7 -3)

@@ -1,6 +1,5 @@
 from __future__ import annotations

-import logging
 from typing import Optional, Union

 import torch
@@ -13,7 +12,6 @@ from src.client.inference_session import RemoteSequentialInferenceSession
 from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
-from src.dht_utils import _create_remote_modules_from_infos
 from src.utils.misc import DUMMY

 use_hivemind_log_handler("in_root_logger")
@@ -59,7 +57,13 @@ class RemoteSequential(nn.Module):
     def __getitem__(self, ix: Union[int, slice]) -> Union["RemoteTransformerBlock", RemoteSequential]:
         assert isinstance(ix, (int, slice))
         if isinstance(ix, int):
-            return self.__getitem__((ix, ix + 1))
+            return RemoteTransformerBlock(
+                self.config,
+                self.dht,
+                dht_prefix=self.dht_prefix,
+                p2p=self.p2p,
+                sequence_manager=self.sequence_manager[ix],
+            )
         else:
             return RemoteSequential(
                 self.config,

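Not part of the diff: a minimal sketch of the reworked indexing, assuming `config` is a `DistributedBloomConfig` and `dht` is a bootstrapped `hivemind.DHT` as in the README snippet above. With this change, an integer index constructs a `RemoteTransformerBlock` directly instead of going through a one-block slice, while slicing still returns a smaller `RemoteSequential`.

```python
import hivemind
from src import DistributedBloomConfig
from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock

# Placeholder address, as in the README.
dht = hivemind.DHT(initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS], client_mode=True, start=True)
config = DistributedBloomConfig.from_pretrained("bigscience/test-bloom-6b3")

sequence = RemoteSequential(config, dht)
block = sequence[3]     # int index -> a single RemoteTransformerBlock with its own sequence manager
subseq = sequence[3:5]  # slice -> a smaller RemoteSequential over the selected blocks

assert isinstance(block, RemoteTransformerBlock)
assert isinstance(subseq, RemoteSequential)
```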
src/client/sequence_manager.py (+12 -9)

@@ -82,6 +82,7 @@ class RemoteSequenceManager:
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
             if info is None:
                 logger.warning(f"Found no block info for block {uid}")
+                continue
             if not isinstance(info, RemoteModuleInfo):
                 logger.warning(f"Unexpected dht entry type for {uid}: {info}")
             if not info.servers:
@@ -95,22 +96,24 @@ class RemoteSequenceManager:
         closed_spans = []
         active_spans = {}
         for block_index, info in enumerate(block_infos):
-            for peer_id, server in info.servers.items():
-                if server.state != ServerState.ONLINE:
-                    continue
-                if peer_id not in active_spans:
-                    active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
-                else:  # peer_id in active_spans
-                    active_spans[peer_id].end = block_index + 1
+            if info is not None:
+                for peer_id, server in info.servers.items():
+                    if server.state != ServerState.ONLINE:
+                        continue
+                    if peer_id not in active_spans:
+                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
+                    else:  # peer_id in active_spans
+                        active_spans[peer_id].end = block_index + 1

             for peer_id in list(active_spans.keys()):
                 if (
-                    peer_id not in info.servers
+                    info is None
+                    or peer_id not in info.servers
                     or info.servers[peer_id].state != ServerState.ONLINE
                     or block_index == len(block_infos) - 1
                 ):
                     closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans
+        assert not active_spans, f"spans: {active_spans}"

         closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)

src/client/sequential_autograd.py (+1 -1)

@@ -110,7 +110,7 @@ async def sequential_forward(
     If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
     """

-    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
+    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"

     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)

src/dht_utils.py (+36 -1)

@@ -8,7 +8,8 @@ from functools import partial
 from typing import Dict, List, Optional, Sequence, Union

 from hivemind.dht import DHT, DHTNode, DHTValue
-from hivemind.p2p import P2P, PeerID
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler

 import src
@@ -71,6 +72,40 @@ async def _declare_active_modules(
     )


+def get_remote_module(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+    return_future: bool = False,
+) -> List[src.RemoteTransformerBlock]:
+    """
+    :param uid_or_uids: find one or more modules with these ids from across the DHT
+    :param config: model config, usually obtained via .from_pretrained(MODEL_NAME)
+    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+    :returns: a list of [RemoteTransformerBlock]
+    """
+    return RemoteExpertWorker.run_coroutine(
+        _get_distinct_blocks(dht, uid_or_uids, config, dht_prefix), return_future=return_future
+    )
+
+
+async def _get_distinct_blocks(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+):
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    uids = [uid_or_uids] if single_uid else uid_or_uids
+    p2p = await dht.replicate_p2p()
+    managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
+    modules = [
+        src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
+    ]
+    return modules[0] if single_uid else modules
+
+
 def get_remote_module_infos(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],

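Not part of the diff: a short usage sketch for the reimplemented `get_remote_module`, assuming a bootstrapped `dht` and the test model names from the README (the module UIDs below are illustrative). Per `_get_distinct_blocks`, a single UID yields one `RemoteTransformerBlock`, while a list of UIDs yields a list with one block per UID.

```python
import hivemind
from src import DistributedBloomConfig, get_remote_module

# Placeholder address, as in the README.
dht = hivemind.DHT(initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS], client_mode=True, start=True)
config = DistributedBloomConfig.from_pretrained("bigscience/test-bloom-6b3")

# A single UID returns one RemoteTransformerBlock ...
block = get_remote_module(dht, "bigscience/test-bloomd-6b3.3", config)

# ... while a list of UIDs returns a list of blocks, one per UID.
block3, block4 = get_remote_module(dht, ["bigscience/test-bloomd-6b3.3", "bigscience/test-bloomd-6b3.4"], config)
```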
tests/test_block_exact_match.py (+6 -5)

@@ -9,18 +9,19 @@ from test_utils import *

 import src
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_sequential import RemoteSequential
+from src.client.remote_sequential import RemoteTransformerBlock
+from src.data_structures import UID_DELIMITER
+from src.dht_utils import get_remote_module


 @pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-
     config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+
     for block_index in random.sample(range(config.n_layer), 3):
-        remote_block = RemoteSequential(config, dht)[block_index]
-        assert isinstance(remote_block, RemoteSequential)
+        remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
+        assert isinstance(remote_block, RemoteTransformerBlock)

         inputs = torch.randn(1, 8, config.hidden_size)
         (outputs_forward,) = remote_block(inputs)

tests/test_chained_calls.py (+9 -7)

@@ -13,16 +13,18 @@ from test_utils import *

 import src
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_sequential import RemoteSequential
+from src.client.remote_sequential import RemoteTransformerBlock
+from src.data_structures import UID_DELIMITER
+from src.dht_utils import get_remote_module


 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = RemoteSequential(src.DistributedBloomConfig.from_pretrained(MODEL_NAME), dht)[0]
+    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0", config)
     assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteSequential)
+    assert isinstance(remote_block, RemoteTransformerBlock)

     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
     remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
@@ -52,10 +54,10 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = RemoteSequential(src.DistributedBloomConfig.from_pretrained(MODEL_NAME), dht)[0]
+    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0", config)
     assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteSequential)
+    assert isinstance(remote_block, RemoteTransformerBlock)

     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
     remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)