
remove transformer block, implement as sequence size 1

Pavel Samygin 3 years ago
parent
commit
5149f47ca7

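With this commit, a single remote block is obtained by indexing a RemoteSequential instead of calling the removed get_remote_module helper. A minimal sketch of the new usage, following the updated tests below (INITIAL_PEERS and MODEL_NAME are the fixtures provided by test_utils):

    import hivemind
    import torch

    import src
    from src.client.remote_sequential import RemoteSequential
    from test_utils import INITIAL_PEERS, MODEL_NAME  # test fixtures, as in the updated tests

    # Join the swarm as a lightweight client and load the distributed config.
    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)

    # Integer indexing now yields a sequence of size 1 rather than a standalone RemoteTransformerBlock.
    remote_block = RemoteSequential(config, dht)[0]

    inputs = torch.randn(1, 8, config.hidden_size)
    (outputs,) = remote_block(inputs)
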
+ 1 - 1
src/__init__.py

@@ -1,6 +1,6 @@
 from src.bloom import *
 from src.client import *
-from src.dht_utils import declare_active_modules, get_remote_module
+from src.dht_utils import declare_active_modules
 
 project_name = "bloomd"
 __version__ = "0.2"

+ 1 - 2
src/client/__init__.py

@@ -1,5 +1,4 @@
 from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
-from src.client.remote_block import RemoteTransformerBlock
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
-from src.client.remote_sequential import RemoteSequential
+from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager

+ 0 - 40
src/client/remote_block.py

@@ -1,40 +0,0 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-from __future__ import annotations
-
-import random
-
-import torch
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertInfo
-from hivemind.p2p import P2P, StubBase
-from hivemind.utils import get_logger, use_hivemind_log_handler
-
-from src.client.inference_session import RemoteTransformerBlockInferenceSession
-from src.data_structures import RemoteModuleInfo
-from src.server.handler import TransformerConnectionHandler
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-class RemoteTransformerBlock(RemoteExpert):
-    """A class that interacts with a remote module on a specific server for forward/backward or inference"""
-
-    def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys())))  # TODO replace this
-        super().__init__(peer_info, p2p)
-
-    @property
-    def stub(self) -> StubBase:
-        return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
-
-    def forward(self, inputs: torch.Tensor, **kwargs):
-        for k, v in kwargs.items():
-            assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
-        return super().forward(inputs)
-
-    def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession:
-        """Initialize a new inference session with the specified remote server"""
-        return RemoteExpertWorker.run_coroutine(
-            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs)
-        )

+ 10 - 5
src/client/remote_sequential.py

@@ -10,7 +10,6 @@ from torch import nn
 
 import src
 from src.client.inference_session import RemoteSequentialInferenceSession
-from src.client.remote_block import RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
@@ -57,12 +56,10 @@ class RemoteSequential(nn.Module):
         outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
         return outputs
 
-    def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
+    def __getitem__(self, ix: Union[int, slice]) -> Union["RemoteTransformerBlock", RemoteSequential]:
         assert isinstance(ix, (int, slice))
         if isinstance(ix, int):
-            assert 0 <= ix < len(self)
-            (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
-            return module
+            return self.__getitem__((ix, ix + 1))
         else:
             return RemoteSequential(
                 self.config,
@@ -85,3 +82,11 @@
 
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
+
+
+class RemoteTransformerBlock(RemoteSequential):
+    """Single transformer block hosted by swarm"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert len(self) == 1, "Remote Block is a sequence size 1"
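
RemoteTransformerBlock survives only as a RemoteSequential constrained to exactly one block, and integer indexing is routed through the same path as slicing. A hedged sketch of the intended semantics, reusing the config and dht set up in the example above:

    full = RemoteSequential(config, dht)    # spans all served blocks of the model
    prefix = full[0 : config.n_layer // 2]  # slice -> RemoteSequential over a sub-range
    single = full[3]                        # int -> the size-1 range [3, 4)
    assert isinstance(single, RemoteSequential)  # the updated tests check exactly this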

+ 0 - 44
src/dht_utils.py

@@ -8,7 +8,6 @@ from functools import partial
 from typing import Dict, List, Optional, Sequence, Union
 
 from hivemind.dht import DHT, DHTNode, DHTValue
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
@@ -72,37 +71,6 @@ async def _declare_active_modules(
     )
 
 
-def get_remote_module(
-    dht: DHT,
-    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    expiration_time: Optional[DHTExpiration] = None,
-    return_future: bool = False,
-) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
-    """
-    :param uid_or_uids: find one or more modules with these ids from across the DHT
-    :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
-    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-    :returns: a list of [RemoteTransformerBlock if found else None]
-    """
-    single_uid = isinstance(uid_or_uids, ModuleUID)
-    uids = [uid_or_uids] if single_uid else uid_or_uids
-    infos = dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
-    )
-
-    if return_future:
-
-        async def _unpack(infos_future: MPFuture, dht: DHT):
-            p2p = await dht.replicate_p2p()
-            modules = _create_remote_modules_from_infos(await infos_future, p2p)
-            return modules[0] if single_uid else modules
-
-        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
-    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
-    modules = _create_remote_modules_from_infos(infos, p2p)
-    return modules[0] if single_uid else modules
-
-
 def get_remote_module_infos(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
@@ -149,15 +117,3 @@ async def _get_remote_module_infos(
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)
     return modules
-
-
-def _create_remote_modules_from_infos(
-    infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
-) -> List[Optional[src.RemoteTransformerBlock]]:
-    modules: List[Optional[src.RemoteTransformerBlock]] = []
-    for info in infos:
-        if info is not None:
-            modules.append(src.RemoteTransformerBlock(info, p2p))
-        else:
-            modules.append(None)
-    return modules
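
With get_remote_module and _create_remote_modules_from_infos removed, metadata lookup still goes through get_remote_module_infos, while constructing callable modules goes through RemoteSequential. A hedged sketch of that split, using the MODEL_NAME.{index} uid convention seen in the tests and the dht/config from the first example:

    from src.dht_utils import get_remote_module_infos

    # Metadata only: which servers currently announce these block uids.
    infos = get_remote_module_infos(dht, [f"{MODEL_NAME}.3", f"{MODEL_NAME}.4"])

    # Callable modules: build a RemoteSequential and index or slice it.
    blocks_3_to_4 = RemoteSequential(config, dht)[3:5]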

+ 5 - 7
tests/test_block_exact_match.py

@@ -7,10 +7,9 @@ import transformers
 from hivemind import P2PHandlerError
 from test_utils import *
 
+import src
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.data_structures import UID_DELIMITER
-from src.dht_utils import get_remote_module
+from src.client.remote_sequential import RemoteSequential
 
 
 @pytest.mark.forked
@@ -18,11 +17,10 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
 
+    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
     for block_index in random.sample(range(config.n_layer), 3):
-        block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}"
-        remote_block = get_remote_module(dht, block_uid)
-        assert remote_block is not None, f"Could not find {block_uid} in DHT"
-        assert isinstance(remote_block, RemoteTransformerBlock)
+        remote_block = RemoteSequential(config, dht)[block_index]
+        assert isinstance(remote_block, RemoteSequential)
 
         inputs = torch.randn(1, 8, config.hidden_size)
         (outputs_forward,) = remote_block(inputs)

+ 6 - 6
tests/test_chained_calls.py

@@ -11,18 +11,18 @@ import transformers
 from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
 from test_utils import *
 
+import src
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.dht_utils import get_remote_module
+from src.client.remote_sequential import RemoteSequential
 
 
 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
+    remote_block = RemoteSequential(src.DistributedBloomConfig.from_pretrained(MODEL_NAME), dht)[0]
     assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
+    assert isinstance(remote_block, RemoteSequential)
 
     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
     remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
@@ -53,9 +53,9 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
+    remote_block = RemoteSequential(src.DistributedBloomConfig.from_pretrained(MODEL_NAME), dht)[0]
     assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
+    assert isinstance(remote_block, RemoteSequential)
 
     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
     remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)