|
@@ -1,20 +1,18 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
-import dataclasses
|
|
|
import logging
|
|
|
-import threading
|
|
|
-from functools import partial
|
|
|
-from typing import Optional, Tuple, NamedTuple, List, Sequence
|
|
|
+import random
|
|
|
|
|
|
import torch
|
|
|
-from hivemind import DHT, get_logger, use_hivemind_log_handler, PeerID
|
|
|
+from hivemind import DHT, get_logger, use_hivemind_log_handler
|
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
-from hivemind.proto import runtime_pb2
|
|
|
from torch import nn
|
|
|
|
|
|
from src import DistributedBloomConfig
|
|
|
-from src.data_structures import UID_DELIMITER, RemoteModuleInfo, ModuleUID
|
|
|
-from src.dht_utils import _create_remote_modules_from_infos, _get_remote_module_infos
|
|
|
+from src.client.remote_sequence_info import RemoteSequenceInfo
|
|
|
+from src.data_structures import UID_DELIMITER
|
|
|
+from src.dht_utils import _create_remote_modules_from_infos
|
|
|
+
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
@@ -43,7 +41,7 @@ class RemoteSequential(nn.Sequential):
|
|
|
|
|
|
self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
|
|
|
logger.debug(f"Remote block uids: {self.block_uids}")
|
|
|
- self.remote_model_info = RemoteModelInfo(dht, self.block_uids)
|
|
|
+ self.remote_model_info = RemoteSequenceInfo(dht, self.block_uids)
|
|
|
|
|
|
def forward(self, inputs: torch.Tensor):
|
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
|
|
@@ -74,98 +72,32 @@ class RemoteSequential(nn.Sequential):
|
|
|
|
|
|
def inference_session(self) -> RemoteSequentialInferenceSession:
    """Call ``update_()`` on ``self.remote_model_info``, then wrap it in a new inference session."""
    sequence_info = self.remote_model_info
    sequence_info.update_()
    return RemoteSequentialInferenceSession(sequence_info)
|
|
|
|
|
|
|
|
|
-Span = NamedTuple('Span', [('start', int), ('end', Optional[int]), ('peer_id', PeerID)])
|
|
|
|
|
|
|
|
|
-@dataclasses.dataclass(frozen=False, init=False)
|
|
|
-class RemoteModelInfo:
|
|
|
- """Stores meta-information about which peers host which blocks - and prepare to form sessions"""
|
|
|
- dht: DHT
|
|
|
- block_uids: Tuple[ModuleUID, ...]
|
|
|
- block_infos: List[Optional[RemoteModuleInfo], ...]
|
|
|
- spans_by_priority: List[Span] # sorted from best to worst
|
|
|
- spans_containing_block: Tuple[List[Span], ...]
|
|
|
- lock_changes: threading.Lock
|
|
|
class RemoteSequentialInferenceSession:
    """
    An interface to a multi-step *inference* session for a sequence of remote transformer blocks.

    On construction the session builds ``active_chain``: a list of ``(start, end, span)`` tuples, where
    ``start`` is the first block not covered by the previous entry and ``end`` is ``span.end``. For each
    such ``start``, one span is picked at random from ``remote_sequence_info.spans_containing_block[start]``,
    and the process repeats from ``span.end`` until it reaches ``len(remote_sequence_info)``.

    :param remote_sequence_info: object that supports ``len()`` and exposes ``spans_containing_block``,
        an index-able collection that holds, for every block index, the candidate spans for that block;
        every span must provide ``start`` and ``end`` attributes
    :raises IndexError: if there are no candidate spans for a block that still has to be covered
    :raises AssertionError: if the chosen span does not contain the block it was listed under
    """

    def __init__(self, remote_sequence_info: "RemoteSequenceInfo"):
        self.remote_sequence_info = remote_sequence_info
        self.closed = False

        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
        current_final_block = 0
        self.active_chain = []

        while current_final_block != len(remote_sequence_info):
            candidate_spans = remote_sequence_info.spans_containing_block[current_final_block]
            if not candidate_spans:
                # random.choice would raise an IndexError with no context here;
                # keep the exception type, but name the block that cannot be covered
                raise IndexError(f"No spans contain block {current_final_block}; cannot build an inference chain")

            chosen_span = random.choice(candidate_spans)  # TODO: this is temporary code
            if not chosen_span.start <= current_final_block < chosen_span.end:
                # raised explicitly instead of using `assert`, so that the check survives `python -O`:
                # without it, a span that does not advance past the current block would stall this loop
                raise AssertionError(
                    f"Span [{chosen_span.start}, {chosen_span.end}) was listed for block {current_final_block} "
                    f"but does not contain it"
                )

            self.active_chain.append((current_final_block, chosen_span.end, chosen_span))
            current_final_block = chosen_span.end
|
|
|
|
|
|
- def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
|
|
|
- self.dht = dht
|
|
|
- self.block_uids = block_uids
|
|
|
- self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
|
|
|
- self.spans_by_priority = []
|
|
|
- self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
|
|
|
- self.lock_changes = threading.Lock()
|
|
|
- self.update_()
|
|
|
-
|
|
|
- for uid, info in zip(self.block_uids, self.block_infos):
|
|
|
- assert info is not None, f"Found no remote peers for block {uid}"
|
|
|
- assert self.spans_by_priority and self.spans_containing_block
|
|
|
-
|
|
|
- def update_(self):
|
|
|
- with self.lock_changes:
|
|
|
- self.update_block_infos_()
|
|
|
- self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
|
|
|
-
|
|
|
- def update_block_infos_(self):
|
|
|
- new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
|
|
|
- partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
|
|
|
- return_future=False)
|
|
|
- assert len(new_block_infos) == len(self.block_uids)
|
|
|
- for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
|
|
|
- if info is None:
|
|
|
- logger.warning(f"Found no block info for block {uid}")
|
|
|
- if not isinstance(info, RemoteModuleInfo):
|
|
|
- logger.warning(f"Unexpected dht entry type for {uid}: {info}")
|
|
|
- if not info.peer_ids:
|
|
|
- logger.warning(f"Found no active peers for block {uid}")
|
|
|
- if info.uid != uid:
|
|
|
- logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
|
|
|
- if not isinstance(info.peer_ids, set):
|
|
|
- logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
|
|
|
- self.block_infos[block_index] = info
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
|
|
|
- closed_spans = []
|
|
|
- active_spans = {}
|
|
|
- for block_index, info in enumerate(block_infos):
|
|
|
- for peer_id in info.peer_ids:
|
|
|
- if peer_id not in active_spans:
|
|
|
- active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
|
|
|
- else: # peer_id in active_spans
|
|
|
- active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
|
|
|
-
|
|
|
- for peer_id in list(active_spans.keys()):
|
|
|
- if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
|
|
|
- closed_spans.append(active_spans.pop(peer_id))
|
|
|
- assert not active_spans
|
|
|
-
|
|
|
- closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
|
|
|
-
|
|
|
- spans_containing_block = tuple(list() for _ in range(len(block_infos)))
|
|
|
- for span in closed_spans:
|
|
|
- for block_index in range(span.start, span.end):
|
|
|
- spans_containing_block[block_index].append(span)
|
|
|
-
|
|
|
- return closed_spans, spans_containing_block
|
|
|
|
|
|
|
|
|
-class RemoteSequentialInferenceSession:
|
|
|
- pass
|
|
|
-# """An interface to a multi-step *inference* session for a sequence of remote modules"""
|
|
|
-#
|
|
|
-# def __init__(self, block):
|
|
|
-# self.closed = False
|
|
|
-#
|
|
|
-# @classmethod
|
|
|
-# async def _create(cls, remote_sequential: RemoteSequential, **kwargs) -> RemoteSequentialInferenceSession:
|
|
|
-# """Create a new session for a sequence of modules. This code is meant to be run inside RemoteExpertWorker"""
|
|
|
-#
|
|
|
-# remote_sequential.
|
|
|
-# return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
|
|
|
-#
|
|
|
# def step(self, new_hidden_states: torch.Tensor):
|
|
|
# """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
|
|
|
# if self.closed:
|