
remove transformer block, implement as sequence size 1

Pavel Samygin, 3 years ago
commit 5149f47ca7
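
In short: the standalone `RemoteTransformerBlock` (a `RemoteExpert` bound to a single server, deleted from `src/client/remote_block.py` below) is reimplemented as a `RemoteSequential` that spans exactly one block, so single-block and multi-block clients share one code path. A minimal usage sketch under stated assumptions: `INITIAL_PEERS` and `MODEL_NAME` are placeholders for your swarm's bootstrap peers and model, and the `RemoteSequential(config, dht)` constructor is taken from the test diffs below.

```python
import hivemind

import src  # this repo's package; re-exports RemoteSequential and RemoteTransformerBlock

# Placeholders: substitute your swarm's bootstrap multiaddrs and model name.
INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/..."]
MODEL_NAME = "bigscience/test-bloomd-6b3"

dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)

sequence = src.RemoteSequential(config, dht)  # all transformer blocks in the swarm
block = sequence[3]  # after this commit: a RemoteSequential covering exactly one block
```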

+ 1 - 1
src/__init__.py

@@ -1,6 +1,6 @@
 from src.bloom import *
 from src.client import *
-from src.dht_utils import declare_active_modules, get_remote_module
+from src.dht_utils import declare_active_modules
 
 project_name = "bloomd"
 __version__ = "0.2"

+ 1 - 2
src/client/__init__.py

@@ -1,5 +1,4 @@
 from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
-from src.client.remote_block import RemoteTransformerBlock
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
-from src.client.remote_sequential import RemoteSequential
+from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
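
Note that the public import path is preserved: `RemoteTransformerBlock` is now re-exported from `src.client.remote_sequential`, so package-level imports keep working. A quick sanity check, assuming the package layout above:

```python
# Both names should refer to the same class after this commit:
from src.client import RemoteTransformerBlock
from src.client.remote_sequential import RemoteTransformerBlock as RTB

assert RemoteTransformerBlock is RTB
```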

+ 0 - 40
src/client/remote_block.py

@@ -1,40 +0,0 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-from __future__ import annotations
-
-import random
-
-import torch
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertInfo
-from hivemind.p2p import P2P, StubBase
-from hivemind.utils import get_logger, use_hivemind_log_handler
-
-from src.client.inference_session import RemoteTransformerBlockInferenceSession
-from src.data_structures import RemoteModuleInfo
-from src.server.handler import TransformerConnectionHandler
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-class RemoteTransformerBlock(RemoteExpert):
-    """A class that interacts with a remote module on a specific server for forward/backward or inference"""
-
-    def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys())))  # TODO replace this
-        super().__init__(peer_info, p2p)
-
-    @property
-    def stub(self) -> StubBase:
-        return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
-
-    def forward(self, inputs: torch.Tensor, **kwargs):
-        for k, v in kwargs.items():
-            assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
-        return super().forward(inputs)
-
-    def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession:
-        """Initialize a new inference session with the specified remote server"""
-        return RemoteExpertWorker.run_coroutine(
-            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs)
-        )

+ 10 - 5
src/client/remote_sequential.py

@@ -10,7 +10,6 @@ from torch import nn
 
 import src
 from src.client.inference_session import RemoteSequentialInferenceSession
-from src.client.remote_block import RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
@@ -57,12 +56,10 @@ class RemoteSequential(nn.Module):
         outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
         return outputs
 
-    def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
+    def __getitem__(self, ix: Union[int, slice]) -> Union["RemoteTransformerBlock", RemoteSequential]:
         assert isinstance(ix, (int, slice))
         if isinstance(ix, int):
-            assert 0 <= ix < len(self)
-            (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
-            return module
+            return self[ix : ix + 1]  # a slice, not a tuple: yields a length-1 RemoteSequential
         else:
             return RemoteSequential(
                 self.config,
@@ -85,3 +82,11 @@ class RemoteSequential(nn.Module):
 
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
+
+
+class RemoteTransformerBlock(RemoteSequential):
+    """Single transformer block hosted by swarm"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert len(self) == 1, "Remote Block is a sequence of size 1"
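
With this hunk, integer indexing reuses the slice branch instead of building a per-server `RemoteExpert`, so `sequence[i]` and `sequence[i:j]` both return `RemoteSequential` instances. A sketch of the resulting semantics, reusing the hypothetical `sequence` from the sketch above:

```python
block = sequence[3]  # routed through the slice branch: covers exactly block 3
assert isinstance(block, src.RemoteSequential) and len(block) == 1

sub = sequence[2:5]  # covers blocks 2, 3, and 4
assert len(sub) == 3

# RemoteTransformerBlock is now just a RemoteSequential constrained to length 1.
```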

+ 0 - 44
src/dht_utils.py

@@ -8,7 +8,6 @@ from functools import partial
 from typing import Dict, List, Optional, Sequence, Union
 
 from hivemind.dht import DHT, DHTNode, DHTValue
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
@@ -72,37 +71,6 @@ async def _declare_active_modules(
     )
 
 
-def get_remote_module(
-    dht: DHT,
-    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    expiration_time: Optional[DHTExpiration] = None,
-    return_future: bool = False,
-) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
-    """
-    :param uid_or_uids: find one or more modules with these ids from across the DHT
-    :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
-    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-    :returns: a list of [RemoteTransformerBlock if found else None]
-    """
-    single_uid = isinstance(uid_or_uids, ModuleUID)
-    uids = [uid_or_uids] if single_uid else uid_or_uids
-    infos = dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
-    )
-
-    if return_future:
-
-        async def _unpack(infos_future: MPFuture, dht: DHT):
-            p2p = await dht.replicate_p2p()
-            modules = _create_remote_modules_from_infos(await infos_future, p2p)
-            return modules[0] if single_uid else modules
-
-        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
-    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
-    modules = _create_remote_modules_from_infos(infos, p2p)
-    return modules[0] if single_uid else modules
-
-
 def get_remote_module_infos(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
@@ -149,15 +117,3 @@ async def _get_remote_module_infos(
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)
     return modules
-
-
-def _create_remote_modules_from_infos(
-    infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
-) -> List[Optional[src.RemoteTransformerBlock]]:
-    modules: List[Optional[src.RemoteTransformerBlock]] = []
-    for info in infos:
-        if info is not None:
-            modules.append(src.RemoteTransformerBlock(info, p2p))
-        else:
-            modules.append(None)
-    return modules
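
With `get_remote_module` and `_create_remote_modules_from_infos` gone, callers resolve a single block through `RemoteSequential` indexing instead of a direct DHT lookup. A hedged before/after sketch of the migration, reusing the `dht` and `MODEL_NAME` placeholders from above:

```python
# Before this commit (API removed in this diff):
#   remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")

# After this commit:
config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
remote_block = src.RemoteSequential(config, dht)[0]
```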

+ 5 - 7
tests/test_block_exact_match.py

@@ -7,10 +7,9 @@ import transformers
 from hivemind import P2PHandlerError
 from test_utils import *
 
+import src
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.data_structures import UID_DELIMITER
-from src.dht_utils import get_remote_module
+from src.client.remote_sequential import RemoteSequential
 
 
 @pytest.mark.forked
@@ -18,11 +17,10 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
 
+    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
     for block_index in random.sample(range(config.n_layer), 3):
-        block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}"
-        remote_block = get_remote_module(dht, block_uid)
-        assert remote_block is not None, f"Could not find {block_uid} in DHT"
-        assert isinstance(remote_block, RemoteTransformerBlock)
+        remote_block = RemoteSequential(config, dht)[block_index]
+        assert isinstance(remote_block, RemoteSequential)
 
         inputs = torch.randn(1, 8, config.hidden_size)
         (outputs_forward,) = remote_block(inputs)

+ 6 - 6
tests/test_chained_calls.py

@@ -11,18 +11,18 @@ import transformers
 from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
 from test_utils import *
 
+import src
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.dht_utils import get_remote_module
+from src.client.remote_sequential import RemoteSequential
 
 
 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
+    remote_block = RemoteSequential(src.DistributedBloomConfig.from_pretrained(MODEL_NAME), dht)[0]
     assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
+    assert isinstance(remote_block, RemoteSequential)
 
     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
     remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
@@ -53,9 +53,9 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
+    remote_block = RemoteSequential(src.DistributedBloomConfig.from_pretrained(MODEL_NAME), dht)[0]
     assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
+    assert isinstance(remote_block, RemoteSequential)
 
     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
     remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)