Browse Source

refactor, add swarm info

Aleksandr Borzunov 3 năm trước cách đây
mục cha
commit
b78d713347

+ 1 - 1
README.md

@@ -65,7 +65,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 
 # test inference, one block
-with layer3.begin_inference_session() as sess:
+with layer3.inference_session() as sess:
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
 ```

+ 10 - 2
src/client/remote_block.py

@@ -11,13 +11,17 @@ from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import anext, nested_flatten
+from hivemind.utils import anext, nested_flatten, use_hivemind_log_handler, get_logger
 
 from src.data_structures import RemoteModuleInfo
 from src.dht_utils import ModuleUID
 from src.server.handler import TransformerConnectionHandler
 
 
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
 class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
@@ -34,11 +38,15 @@ class RemoteTransformerBlock(RemoteExpert):
             assert v is None, f"Extra keyword arguments are not yet supported (got {k} = {v})"
         return super().forward(inputs)
 
-    def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
+    def inference_session(self) -> RemoteTransformerBlockInferenceSession:
         """Initialize a new inference session with the specified remote server"""
         _ = self.info  # create _info manually since the built-in property will not work inside RemoteExpertWorker
         return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
 
+    def begin_inference_session(self):
+        logger.warning("beging_inference_session was renamed to just inference_session")
+        return self.inference_session()
+
 
 class RemoteTransformerBlockInferenceSession:
     """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""

+ 158 - 20
src/client/remote_sequential.py

@@ -1,14 +1,19 @@
+from __future__ import annotations
+
+import dataclasses
 import logging
+import threading
 from functools import partial
-from typing import Optional, Tuple
+from typing import Optional, Tuple, NamedTuple, List, Sequence
 
 import torch
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import DHT, get_logger, use_hivemind_log_handler, PeerID
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.proto import runtime_pb2
 from torch import nn
 
 from src import DistributedBloomConfig
-from src.data_structures import UID_DELIMITER, RemoteModuleInfo
+from src.data_structures import UID_DELIMITER, RemoteModuleInfo, ModuleUID
 from src.dht_utils import _create_remote_modules_from_infos, _get_remote_module_infos
 
 use_hivemind_log_handler("in_root_logger")
@@ -32,27 +37,13 @@ class RemoteSequential(nn.Sequential):
         super().__init__()
         self.config = config
         self.dht = dht
+        self.prefix = prefix
+        self.max_retries = max_retries
         self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
 
-        self.prefix = prefix
         self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
         logger.debug(f"Remote block uids: {self.block_uids}")
-        self.block_infos: Tuple[RemoteModuleInfo, ...] = tuple(
-            dht.run_coroutine(
-                partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
-                return_future=False,
-            )
-        )
-
-        self.max_retries = max_retries
-
-        assert len(self.block_infos) == len(self.block_uids)
-        for uid, info in zip(self.block_uids, self.block_infos):
-            assert isinstance(info, (type(None), RemoteModuleInfo)), f"Unexpected dht entry for {uid}: {info}"
-            assert info is not None, f"Found no active peers for block {uid}"
-            assert isinstance(info.peer_ids, set), f"expected peer_ids to be a set, got {info.peer_ids}"
-            assert info.uid == uid, f"The DHT entry for {uid} actually points to {info.uid}"
-            assert len(info.peer_ids) > 0, f"Found no active peers for block {uid}"
+        self.remote_model_info = RemoteModelInfo(dht, self.block_uids)
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@@ -80,3 +71,150 @@ class RemoteSequential(nn.Sequential):
     def __iter__(self):
         for block_index in range(self.config.n_layer):
             yield self[block_index]
+
+    def inference_session(self) -> RemoteSequentialInferenceSession:
+        self.remote_model_info.update_()
+        return RemoteExpertWorker.run_coroutine(RemoteSequentialInferenceSession._create(self))
+
+
+Span = NamedTuple('Span', [('start', int), ('end', Optional[int]), ('peer_id', PeerID)])
+
+
+@dataclasses.dataclass(frozen=False, init=False)
+class RemoteModelInfo:
+    """Stores meta-information about which peers host which blocks - and prepare to form sessions"""
+    dht: DHT
+    block_uids: Tuple[ModuleUID, ...]
+    block_infos: List[Optional[RemoteModuleInfo], ...]
+    spans_by_priority: List[Span]  # sorted from best to worst
+    spans_containing_block: Tuple[List[Span], ...]
+    lock_changes: threading.Lock
+
+    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
+        self.dht = dht
+        self.block_uids = block_uids
+        self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
+        self.spans_by_priority = []
+        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
+        self.lock_changes = threading.Lock()
+        self.update_()
+
+        for uid, info in zip(self.block_uids, self.block_infos):
+            assert info is not None, f"Found no remote peers for block {uid}"
+        assert self.spans_by_priority and self.spans_containing_block
+
+    def update_(self):
+        with self.lock_changes:
+            self.update_block_infos_()
+            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+
+    def update_block_infos_(self):
+        new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
+            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
+            return_future=False)
+        assert len(new_block_infos) == len(self.block_uids)
+        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
+            if info is None:
+                logger.warning(f"Found no block info for block {uid}")
+            if not isinstance(info, RemoteModuleInfo):
+                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
+            if not info.peer_ids:
+                logger.warning(f"Found no active peers for block {uid}")
+            if info.uid != uid:
+                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
+            if not isinstance(info.peer_ids, set):
+                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
+            self.block_infos[block_index] = info
+
+    @staticmethod
+    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
+        closed_spans = []
+        active_spans = {}
+        for block_index, info in enumerate(block_infos):
+            for peer_id in info.peer_ids:
+                if peer_id not in active_spans:
+                    active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
+                else:  # peer_id in active_spans
+                    active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
+
+            for peer_id in list(active_spans.keys()):
+                if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
+                    closed_spans.append(active_spans.pop(peer_id))
+        assert not active_spans
+
+        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
+
+        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
+        for span in closed_spans:
+            for block_index in range(span.start, span.end):
+                spans_containing_block[block_index].append(span)
+
+        return closed_spans, spans_containing_block
+
+# 
+# class RemoteSequentialInferenceSession:
+#     """An interface to a multi-step *inference* session for a sequence of remote modules"""
+# 
+#     def __init__(self, block):
+#         self.closed = False
+# 
+#     @classmethod
+#     async def _create(cls, remote_sequential: RemoteSequential, **kwargs) -> RemoteSequentialInferenceSession:
+#         """Create a new session for a sequence of modules. This code is meant to be run inside RemoteExpertWorker"""
+# 
+#         remote_sequential.
+#         return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
+# 
+#     def step(self, new_hidden_states: torch.Tensor):
+#         """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+#         if self.closed:
+#             raise Exception("Session is closed, cannot perform step")
+#         # serialize inputs and put them into the queue
+#         inputs = (new_hidden_states,)
+#         outputs_serialized = RemoteExpertWorker.run_coroutine(
+#             self._step(
+#                 runtime_pb2.ExpertRequest(
+#                     uid=self.uid,
+#                     tensors=[
+#                         serialize_torch_tensor(tensor, proto.compression)
+#                         for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
+#                     ],
+#                 )
+#             )
+#         )
+#         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
+#         assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+#         return outputs[0]
+# 
+#     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
+#         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
+#         await self._inputs_queue.put(inputs_serialized)
+#         return await anext(self._outputs_stream)
+# 
+#     def close(self):
+#         """Finish a given inference session, close the underlying connection"""
+#         if self._outputs_stream is None:
+#             return  # already closed
+#         RemoteExpertWorker.run_coroutine(self._aclose_stream())
+#         self._outputs_stream = self._inputs_queue = None
+#         self.closed = True
+# 
+#     async def _aclose_stream(self):
+#         """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
+#         if self._outputs_stream is None:
+#             return  # already closed
+#         await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
+#         try:
+#             await anext(self._outputs_stream)
+#         except StopAsyncIteration:
+#             pass
+# 
+#     def __del__(self):
+#         self.close()
+# 
+#     def __enter__(self):
+#         assert not self.closed
+#         return self
+# 
+#     def __exit__(self, *exc_details):
+#         self.close()

+ 1 - 1
tests/test_block_exact_match.py

@@ -32,7 +32,7 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     (outputs_forward,) = remote_block(inputs)
 
     outputs_inference = []
-    with remote_block.begin_inference_session() as sess:
+    with remote_block.inference_session() as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)

+ 1 - 1
tests/test_chained_inference.py

@@ -39,7 +39,7 @@ def test_remote_block_exact_match(atol_inference=1e-4):
     inputs = torch.randn(1, 8, 4096)
 
     outputs_inference = []
-    with remote_block.begin_inference_session() as sess:
+    with remote_block.inference_session() as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)