justheuristic, 3 years ago
Current commit ec477b910b
2 files changed: 119 insertions and 94 deletions
  1. src/client/remote_sequence_info.py  (+93, -0)
  2. src/client/remote_sequential.py  (+26, -94)

src/client/remote_sequence_info.py  (+93, -0)

@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+import dataclasses
+import threading
+from functools import partial
+from typing import Tuple, List, Optional, Sequence, NamedTuple
+
+from hivemind import DHT, PeerID
+from hivemind.utils.logging import use_hivemind_log_handler, get_logger
+
+from src.data_structures import ModuleUID, RemoteModuleInfo
+from src.dht_utils import _get_remote_module_infos
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+Span = NamedTuple('Span', [('start', int), ('end', Optional[int]), ('peer_id', PeerID)])
+
+
+@dataclasses.dataclass(frozen=False, init=False)
+class RemoteSequenceInfo:
+    """Keeps and updates the meta-information about which peers host which blocks"""
+    dht: DHT
+    block_uids: List[ModuleUID]
+    block_infos: List[Optional[RemoteModuleInfo]]
+    spans_by_priority: List[Span]  # sorted from best to worst
+    spans_containing_block: Tuple[List[Span], ...]
+    lock_changes: threading.Lock
+
+    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
+        self.dht = dht
+        self.block_uids = list(block_uids)
+        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
+        self.spans_by_priority = []
+        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
+        self.lock_changes = threading.Lock()
+        self.update_()
+
+        for uid, info in zip(self.block_uids, self.block_infos):
+            assert info is not None, f"Found no remote peers for block {uid}"
+        assert self.spans_by_priority and self.spans_containing_block
+
+    def update_(self):
+        with self.lock_changes:
+            self.update_block_infos_()
+            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+
+    def update_block_infos_(self):
+        new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
+            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
+            return_future=False)
+        assert len(new_block_infos) == len(self.block_uids)
+        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
+            if info is None:
+                logger.warning(f"Found no block info for block {uid}")
+                continue  # skip the remaining checks: they would fail on a missing entry
+            if not isinstance(info, RemoteModuleInfo):
+                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
+            if not info.peer_ids:
+                logger.warning(f"Found no active peers for block {uid}")
+            if info.uid != uid:
+                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
+            if not isinstance(info.peer_ids, set):
+                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
+            self.block_infos[block_index] = info
+
+    @staticmethod
+    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
+        closed_spans = []
+        active_spans = {}
+        for block_index, info in enumerate(block_infos):
+            for peer_id in info.peer_ids:
+                if peer_id not in active_spans:
+                    active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
+                else:  # peer_id in active_spans
+                    active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
+
+            for peer_id in list(active_spans.keys()):
+                if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
+                    closed_spans.append(active_spans.pop(peer_id))
+        assert not active_spans
+
+        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
+
+        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
+        for span in closed_spans:
+            for block_index in range(span.start, span.end):
+                spans_containing_block[block_index].append(span)
+
+        return closed_spans, spans_containing_block
+
+    def __len__(self):
+        return len(self.block_uids)

src/client/remote_sequential.py  (+26, -94)

@@ -1,20 +1,18 @@
 from __future__ import annotations
 
-import dataclasses
 import logging
-import threading
-from functools import partial
-from typing import Optional, Tuple, NamedTuple, List, Sequence
+import random
 
 import torch
-from hivemind import DHT, get_logger, use_hivemind_log_handler, PeerID
+from hivemind import DHT, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.proto import runtime_pb2
 from torch import nn
 
 from src import DistributedBloomConfig
-from src.data_structures import UID_DELIMITER, RemoteModuleInfo, ModuleUID
-from src.dht_utils import _create_remote_modules_from_infos, _get_remote_module_infos
+from src.client.remote_sequence_info import RemoteSequenceInfo
+from src.data_structures import UID_DELIMITER
+from src.dht_utils import _create_remote_modules_from_infos
+
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -43,7 +41,7 @@ class RemoteSequential(nn.Sequential):
 
         self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
         logger.debug(f"Remote block uids: {self.block_uids}")
-        self.remote_model_info = RemoteModelInfo(dht, self.block_uids)
+        self.remote_model_info = RemoteSequenceInfo(dht, self.block_uids)
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@@ -74,98 +72,32 @@ class RemoteSequential(nn.Sequential):
 
     def inference_session(self) -> RemoteSequentialInferenceSession:
         self.remote_model_info.update_()
-        return RemoteExpertWorker.run_coroutine(RemoteSequentialInferenceSession._create(self))
+        return RemoteSequentialInferenceSession(self.remote_model_info)
 
 
-Span = NamedTuple('Span', [('start', int), ('end', Optional[int]), ('peer_id', PeerID)])
 
 
-@dataclasses.dataclass(frozen=False, init=False)
-class RemoteModelInfo:
-    """Stores meta-information about which peers host which blocks - and prepare to form sessions"""
-    dht: DHT
-    block_uids: Tuple[ModuleUID, ...]
-    block_infos: List[Optional[RemoteModuleInfo], ...]
-    spans_by_priority: List[Span]  # sorted from best to worst
-    spans_containing_block: Tuple[List[Span], ...]
-    lock_changes: threading.Lock
+class RemoteSequentialInferenceSession:
+    """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
+
+    def __init__(self, remote_sequence_info: RemoteSequenceInfo):
+        self.remote_sequence_info = remote_sequence_info
+        self.closed = False
+
+        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
+        current_final_block = 0
+        self.active_chain = []
+
+        while current_final_block != len(remote_sequence_info):
+            candidate_spans = remote_sequence_info.spans_containing_block[current_final_block]
+            chosen_span = random.choice(candidate_spans)  # TODO this is a temporary code
+            assert chosen_span.start <= current_final_block < chosen_span.end
+
+            self.active_chain.append((current_final_block, chosen_span.end, chosen_span))
+            current_final_block = chosen_span.end
 
-    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
-        self.dht = dht
-        self.block_uids = block_uids
-        self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
-        self.spans_by_priority = []
-        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
-        self.lock_changes = threading.Lock()
-        self.update_()
-
-        for uid, info in zip(self.block_uids, self.block_infos):
-            assert info is not None, f"Found no remote peers for block {uid}"
-        assert self.spans_by_priority and self.spans_containing_block
-
-    def update_(self):
-        with self.lock_changes:
-            self.update_block_infos_()
-            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
-
-    def update_block_infos_(self):
-        new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
-            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
-            return_future=False)
-        assert len(new_block_infos) == len(self.block_uids)
-        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
-            if info is None:
-                logger.warning(f"Found no block info for block {uid}")
-            if not isinstance(info, RemoteModuleInfo):
-                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-            if not info.peer_ids:
-                logger.warning(f"Found no active peers for block {uid}")
-            if info.uid != uid:
-                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-            if not isinstance(info.peer_ids, set):
-                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
-            self.block_infos[block_index] = info
-
-    @staticmethod
-    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
-        closed_spans = []
-        active_spans = {}
-        for block_index, info in enumerate(block_infos):
-            for peer_id in info.peer_ids:
-                if peer_id not in active_spans:
-                    active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
-                else:  # peer_id in active_spans
-                    active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
-
-            for peer_id in list(active_spans.keys()):
-                if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
-                    closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans
-
-        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
-
-        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
-        for span in closed_spans:
-            for block_index in range(span.start, span.end):
-                spans_containing_block[block_index].append(span)
-
-        return closed_spans, spans_containing_block
 
 
-class RemoteSequentialInferenceSession:
-    pass
-#     """An interface to a multi-step *inference* session for a sequence of remote modules"""
-# 
-#     def __init__(self, block):
-#         self.closed = False
-# 
-#     @classmethod
-#     async def _create(cls, remote_sequential: RemoteSequential, **kwargs) -> RemoteSequentialInferenceSession:
-#         """Create a new session for a sequence of modules. This code is meant to be run inside RemoteExpertWorker"""
-# 
-#         remote_sequential.
-#         return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
-# 
 #     def step(self, new_hidden_states: torch.Tensor):
 #         """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
 #         if self.closed: