Ver Fonte

support hosting multiple instances of the same block

justheuristic há 3 anos atrás
pai
commit
2eb47cbedd

+ 1 - 1
README.md

@@ -49,7 +49,7 @@ Then open a python notebook or console and run:
 ```python
 import torch
 import hivemind
-from src.client.remote_block import get_remote_module
+from src import get_remote_module
 
 dht = hivemind.DHT(
     initial_peers=["/ip4/127.0.0.1/COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS"],

+ 2 - 0
src/__init__.py

@@ -1 +1,3 @@
 from .bloom import *
+from .dht_utils import get_remote_module, declare_active_modules
+from .client import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession

+ 1 - 1
src/client/__init__.py

@@ -1 +1 @@
-from src.client.remote_block import RemoteTransformerBlock
+from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession

+ 0 - 29
src/client/inference_chain.py

@@ -1,29 +0,0 @@
-from collections import defaultdict
-from typing import Sequence
-
-import torch
-from hivemind import DHT
-from torch import nn
-
-from src import DistributedBloomConfig
-from src.server.backend import MAX_LENGTH
-
-
-class RemoteInferenceChain(nn.Module):
-    """An auxiliary class that manages distributed inference in a chain of one or more remote transformer modules"""
-
-    def __init__(self, dht: DHT, config: DistributedBloomConfig, block_names: Sequence[str]):
-        super().__init__()
-        self.dht = dht
-        self.config, self.block_names = config, block_names
-        self.block_caches = {name: torch.zeros(1, MAX_LENGTH, config.hidden_size) for name in block_names}
-        self.current_position = 0
-
-    def step(self, hidden_states: torch.Tensor):
-        pass
-
-
-# plan:
-# - run inference STUB from a jupyter notebook
-# - extend to run actual inference
-# - extend to run multiple layers at a time

+ 16 - 61
src/client/remote_block.py

@@ -1,38 +1,45 @@
 from __future__ import annotations
 
 import asyncio
-from functools import partial
-from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Union
+import random
+from typing import Any, AsyncIterator, Dict, Optional
 
 import torch
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.dht import DHT, DHTNode, DHTValue
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertInfo as RemoteModuleInfo
-from hivemind.moe.expert_uid import ExpertUID
-from hivemind.p2p import P2P, PeerID, StubBase
+from hivemind.moe.client.expert import RemoteExpertWorker, RemoteExpert
+from hivemind.moe.expert_uid import ExpertInfo
+from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import DHTExpiration, MPFuture, anext, as_aiter, get_dht_time, nested_flatten
+from hivemind.utils import anext, nested_flatten
 
+from src.dht_utils import ModuleUID
+from src.data_structures import RemoteModuleInfo
 from src.server.handler import TransformerConnectionHandler
 
 
 class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
+    def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
+        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids))) #TODO replace this
+        super().__init__(peer_info, p2p)
+
     @property
     def stub(self) -> StubBase:
         return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
 
     def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
         """Initialize a new inference session with the specified remote server"""
+        _ = self.info  # create _info manually since the built-in property will not work inside RemoteExpertWorker
         return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
 
 
 class RemoteTransformerBlockInferenceSession:
     """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
 
-    def __init__(self, uid: ExpertUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
+    def __init__(
+            self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator
+    ):
         self.uid, self.info = uid, info
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
@@ -113,55 +120,3 @@ class RemoteTransformerBlockInferenceSession:
     def __exit__(self, *exc_details):
         self.close()
 
-
-def get_remote_module(
-    dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
-) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
-    """
-    :param uids: find experts with these ids from across the DHT
-    :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
-    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-    :returns: a list of [RemoteTransformerBlock if found else None]
-    """
-    assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
-    infos = dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=list(uids), expiration_time=expiration_time), return_future
-    )
-
-    if return_future:
-
-        async def _unpack(infos_future: MPFuture, dht: DHT):
-            p2p = await dht.replicate_p2p()
-            return _create_remote_modules_from_infos(await infos_future, p2p)
-
-        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
-    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
-    return _create_remote_modules_from_infos(infos, p2p)
-
-
-async def _get_remote_module_infos(
-    dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
-) -> List[Optional[RemoteModuleInfo]]:
-    if expiration_time is None:
-        expiration_time = get_dht_time()
-    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
-    found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
-
-    experts: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
-    for i, uid in enumerate(uids):
-        server_peer_id = found[uid]
-        if server_peer_id is not None and isinstance(server_peer_id.value, str):
-            experts[i] = RemoteModuleInfo(uid, PeerID.from_base58(server_peer_id.value))
-    return experts
-
-
-def _create_remote_modules_from_infos(
-    infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
-) -> List[Optional[RemoteTransformerBlock]]:
-    experts: List[Optional[RemoteTransformerBlock]] = []
-    for info in infos:
-        if info is not None:
-            experts.append(RemoteTransformerBlock(info, p2p))
-        else:
-            experts.append(None)
-    return experts

+ 8 - 0
src/data_structures.py

@@ -0,0 +1,8 @@
+from typing import NamedTuple, Collection
+
+from hivemind import PeerID
+
+
+ModuleUID = str
+UID_DELIMITER = '.'
+RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])

+ 118 - 0
src/dht_utils.py

@@ -0,0 +1,118 @@
+"""
+Utilities for declaring and retrieving active model layers using a shared DHT.
+"""
+from __future__ import annotations
+
+from functools import partial
+from typing import Dict, List, Optional, Sequence, Union
+
+from hivemind.dht import DHT, DHTNode, DHTValue
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, use_hivemind_log_handler, get_logger
+from hivemind.p2p import P2P, PeerID
+
+import src
+from src.data_structures import RemoteModuleInfo, ModuleUID, UID_DELIMITER
+
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+def declare_active_modules(
+    dht: DHT, uids: Sequence[ModuleUID], expiration_time: DHTExpiration, wait: bool = True
+) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
+    """
+    Declare that your node serves the specified modules; update timestamps if declared previously
+
+    :param uids: a list of module ids to declare
+    :param wait: if True, awaits for declaration to finish, otherwise runs in background
+    :param expiration_time: declated modules will be visible for this many seconds
+    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
+    """
+    if isinstance(uids, str):
+        uids = [uids]
+    if not isinstance(uids, list):
+        uids = list(uids)
+    for uid in uids:
+        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid
+    return dht.run_coroutine(
+        partial(_declare_active_modules, uids=uids, expiration_time=expiration_time), return_future=not wait
+    )
+
+
+async def _declare_active_modules(
+    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: DHTExpiration
+) -> Dict[ModuleUID, bool]:
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
+    return await node.store_many(
+        keys=uids,
+        subkeys=[dht.peer_id.to_base58()] * len(uids),
+        values=[None] * len(uids),
+        expiration_time=expiration_time,
+        num_workers=num_workers
+    )
+
+
+def get_remote_module(
+    dht: DHT, uid_or_uids: Union[ModuleUID, List[ModuleUID]], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
+) -> Union[List[Optional["src.RemoteTransformerBlock"]], MPFuture[List[Optional["src.RemoteTransformerBlock"]]]]:
+    """
+    :param uid_or_uids: find one or more modules with these ids from across the DHT
+    :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
+    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+    :returns: a list of [RemoteTransformerBlock if found else None]
+    """
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    uids = [uid_or_uids] if single_uid else uid_or_uids
+    infos = dht.run_coroutine(
+        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
+    )
+
+    if return_future:
+        async def _unpack(infos_future: MPFuture, dht: DHT):
+            p2p = await dht.replicate_p2p()
+            modules = _create_remote_modules_from_infos(await infos_future, p2p)
+            return modules[0] if single_uid else modules
+
+        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
+    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+    modules = _create_remote_modules_from_infos(infos, p2p)
+    return modules[0] if single_uid else modules
+
+
+async def _get_remote_module_infos(
+        dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
+) -> List[Optional[RemoteModuleInfo]]:
+    if expiration_time is None:
+        expiration_time = get_dht_time()
+    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
+    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
+
+    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
+    for i, uid in enumerate(uids):
+        metadata = found[uid]
+        if metadata is None or not isinstance(metadata.value, dict):
+            logger.error(f"Incorrect metadata for {uid}: {metadata}")
+            continue
+        valid_entries = set()
+        for maybe_peer_id, _unused_value in metadata.value.items():
+            try:
+                valid_entries.add(PeerID.from_base58(maybe_peer_id))
+            except:
+                logger.error(f"Incorrect peer entry for {uid}: {maybe_peer_id}")
+        if valid_entries:
+            modules[i] = RemoteModuleInfo(uid, valid_entries)
+    return modules
+
+
+def _create_remote_modules_from_infos(
+    infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
+) -> List[Optional[src.RemoteTransformerBlock]]:
+    modules: List[Optional[src.RemoteTransformerBlock]] = []
+    for info in infos:
+        if info is not None:
+            modules.append(src.RemoteTransformerBlock(info, p2p))
+        else:
+            modules.append(None)
+    return modules

+ 23 - 2
src/server/server.py

@@ -5,13 +5,14 @@ import threading
 from typing import Dict, Optional, Sequence, Union
 
 import torch
-from hivemind import DHT, BatchTensorDescriptor
+from hivemind import DHT, BatchTensorDescriptor, MAX_DHT_TIME_DISCREPANCY_SECONDS, get_dht_time
 from hivemind.moe.server.dht_handler import DHTHandlerThread
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+from src import declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
@@ -42,7 +43,7 @@ class Server(threading.Thread):
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
-        self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
+        self.dht_handler_thread = ModuleAnnouncerThread(self.module_backends, dht, update_period, expiration, daemon=True)
         self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
 
         if start:
@@ -212,3 +213,23 @@ class Server(threading.Thread):
 
         self.runtime.shutdown()
         logger.info("Server shutdown succesfully")
+
+
+class ModuleAnnouncerThread(threading.Thread):
+    """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
+    def __init__(
+        self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
+    ):
+        super().__init__(**kwargs)
+        if expiration is None:
+            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+        self.module_backends = module_backends
+        self.dht = dht
+        self.update_period = update_period
+        self.expiration = expiration
+        self.stop = threading.Event()
+
+    def run(self) -> None:
+        declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
+        while not self.stop.wait(self.update_period):
+            declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)

+ 3 - 2
tests/test_block_exact_match.py

@@ -4,7 +4,8 @@ import hivemind
 import torch
 
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock, get_remote_module
+from src.client.remote_block import RemoteTransformerBlock
+from src.dht_utils import get_remote_module
 
 INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
 if not INITIAL_PEERS:
@@ -22,7 +23,7 @@ REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID[-1].split(".")[-1]))
 
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    (remote_block,) = get_remote_module(dht, [BLOCK_UID])
+    remote_block = get_remote_module(dht, BLOCK_UID)
     assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
     assert isinstance(remote_block, RemoteTransformerBlock)