فهرست منبع

prototype remote sequential

justheuristic 3 سال پیش
والد
کامیت
5849cea28c
2فایلهای تغییر یافته به همراه94 افزوده شده و 0 حذف شده
  1. 28 0
      src/client/remote_model.py
  2. 66 0
      src/client/remote_sequential.py

+ 28 - 0
src/client/remote_model.py

@@ -0,0 +1,28 @@
+# this code is in active development, interfaces may change
+
+from hivemind import DHT, use_hivemind_log_handler, get_logger
+
+from src.bloom import DistributedBloomConfig, BloomForCausalLM
+from src.client.remote_sequential import RemoteSequential
+from src.data_structures import UID_DELIMITER
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class DistributedBloomForCausalLM(BloomForCausalLM):
+    """BloomForCausalLM, but all transformer layers are hosted by the swarm"""
+    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
+        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
+        if prefix.endswith(UID_DELIMITER):
+            logger.warning(f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
+                           f"This will cause {self.__class__.__name__} to look for modules under "
+                           f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended.")
+
+        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
+        super().__init__(config)
+        assert len(self.transformer.h) == 0
+        config.n_layer = n_layer
+        self.transformer.h = RemoteSequential(config, dht, prefix)
+
+

+ 66 - 0
src/client/remote_sequential.py

@@ -0,0 +1,66 @@
+import logging
+from functools import partial
+from typing import Tuple
+
+import torch
+from hivemind import DHT
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from torch import nn
+
+from src import DistributedBloomConfig
+from src.client.remote_model import logger
+from src.data_structures import UID_DELIMITER, RemoteModuleInfo
+from src.dht_utils import _get_remote_module_infos, _create_remote_modules_from_infos
+
+
+class RemoteSequential(nn.Sequential):
+    """A sequence of transformer blocks hosted by the swarm"""
+    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
+        super().__init__()
+        self.config = config
+        self.dht = dht
+        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+
+        self.prefix = prefix
+        self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
+        logger.debug(f"Remote block uids: {self.block_uids}")
+        self.block_infos: Tuple[RemoteModuleInfo, ...] = tuple(dht.run_coroutine(
+            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float('inf')), return_future=False
+        ))
+
+        self.max_retries = max_retries
+
+        assert len(self.block_infos) == len(self.block_uids)
+        for uid, info in zip(self.block_uids, self.block_infos):
+            assert isinstance(info, (type(None), RemoteModuleInfo)), f"Unexpected dht entry for {uid}: {info}"
+            assert info is not None, f"Found no active peers for block {uid}"
+            assert isinstance(info.peer_ids, (list, tuple)), f"expected peer_ids to be list/tuple, got {info.peer_ids}"
+            assert info.uid == uid, f"The DHT entry for {uid} actually points to {info.uid}"
+            assert len(info.peer_ids) > 0,  f"Found no active peers for block {uid}"
+
+    def forward(self, inputs: torch.Tensor):
+        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
+        for block_index in range(self.config.n_layer):
+            for retry_index in range(self.max_retries):
+                try:
+                    block = self[block_index]
+                    outputs, = block(inputs)
+                    assert isinstance(outputs, torch.Tensor)
+                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
+                    inputs = outputs
+                    break
+                except Exception as e:
+                    if retry_index == self.max_retries - 1:
+                        raise e
+                    else:
+                        logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
+        return inputs
+
+    def __getitem__(self, block_index: int):
+        assert 0 <= block_index < self.config.n_layer
+        module, = _create_remote_modules_from_infos([self.block_infos[block_index]], self.p2p)
+        return module
+
+    def __iter__(self):
+        for block_index in range(self.config.n_layer):
+            yield self[block_index]