Browse Source

implement rpc_info in sequence manager

justheuristic 3 years ago
parent
commit
f670aaf11b
3 changed files with 55 additions and 38 deletions
  1. src/client/remote_block.py (0 additions, 1 deletion)
  2. src/client/remote_sequential.py (14 additions, 18 deletions)
  3. src/client/sequence_manager.py (41 additions, 19 deletions)

+ 0 - 1
src/client/remote_block.py

@@ -8,7 +8,6 @@ from typing import AsyncIterator, Optional
 import torch
 from hivemind import serialize_torch_tensor, nested_flatten, deserialize_torch_tensor, anext
 from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2

+ 14 - 18
src/client/remote_sequential.py

@@ -14,8 +14,9 @@ from torch import nn
 import src
 from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
 from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import UID_DELIMITER, RemoteSpanInfo
+from src.data_structures import UID_DELIMITER, RemoteSpanInfo, CHAIN_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
+from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -31,7 +32,6 @@ class RemoteSequential(nn.Module):
         config: src.DistributedBloomConfig,
         dht: DHT,
         prefix: str,
-        max_retries: int = 3,
         p2p: Optional[P2P] = None,
         sequence_manager: Optional[RemoteSequenceManager] = None,
     ):
@@ -47,13 +47,12 @@ class RemoteSequential(nn.Module):
         self.config = config
         self.dht = dht
         self.prefix = prefix
-        self.max_retries = max_retries
         self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
 
         block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
         if sequence_manager is None:
             logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(dht, block_uids)
+            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
             self.is_subsequence = False
         else:
             assert isinstance(sequence_manager.block_uids, list)
@@ -63,7 +62,7 @@ class RemoteSequential(nn.Module):
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
         for block_index in range(self.config.n_layer):
-            for retry_index in range(self.max_retries):
+            for retry_index in range(self.sequence_manager.max_retries):
                 try:
                     block = self[block_index]
                     (outputs,) = block(inputs)
@@ -72,7 +71,7 @@ class RemoteSequential(nn.Module):
                     inputs = outputs
                     break
                 except Exception as e:
-                    if retry_index == self.max_retries - 1:
+                    if retry_index == self.sequence_manager.max_retries - 1:
                         raise e
                     else:
                         logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
@@ -109,13 +108,14 @@ class RemoteSequential(nn.Module):
 class RemoteSequentialInferenceSession:
     """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P):
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float]=None):
         self.sequence_manager = sequence_manager
         self.p2p = p2p
         self.closed = False
         self.chosen_spans: List[RemoteSpanInfo] = []
         self.stack = contextlib.ExitStack()
         self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
+        self.timeout = timeout
 
     def __enter__(self):
         assert not self.closed and not self.chosen_spans
@@ -124,17 +124,13 @@ class RemoteSequentialInferenceSession:
         self.chosen_spans.extend(self.sequence_manager.make_sequence())
 
         for chosen_span in self.chosen_spans:
-            TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
-
-            # TODO begin throwaway prototype code
-            remote = RemoteTransformerBlock(self.sequence_manager.block_infos[current_block], self.p2p)
-            _ = remote.info  # TODO fix
-            span_uids = self.sequence_manager.block_uids[current_block : chosen_span.end]
-            remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
-            self.inference_sessions.append(remote.inference_session())
-            self.stack.enter_context(self.inference_sessions[-1])
-            current_block = chosen_span.end
-            # TODO end throwaway prototype code
+            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
+            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start: chosen_span.end])
+            inference_session = RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(
+                stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
+            ))
+            self.inference_sessions.append(inference_session)
+            self.stack.enter_context(inference_session)
 
         return self
 

+ 41 - 19
src/client/sequence_manager.py

@@ -4,34 +4,34 @@ import random
 import threading
 from typing import List, Optional, Sequence, Tuple, Union
 
-from hivemind import DHT, DHTExpiration
+from hivemind import DHT, DHTExpiration, P2P, MSGPackSerializer
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
 from src.dht_utils import get_remote_module_infos
+from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
 class RemoteSequenceManager:
-    """Keeps and updates the meta-information about which peers host which blocks"""
-
-    dht: DHT
-    block_uids: List[ModuleUID]
-    block_infos: List[Optional[RemoteModuleInfo]]
-    spans_by_priority: List[RemoteSpanInfo]  # sorted from best to worst
-    spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
-    last_update_time: DHTExpiration
-    lock_changes: threading.Lock
-
-    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
-        self.dht = dht
-        self.block_uids = list(block_uids)
-        self.block_infos = [None] * len(self.block_uids)
-        self.spans_by_priority = []
-        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
-        self.last_update_time = -float("inf")
+    """
+    Keeps and updates the meta-information about which peers host which blocks.
+    In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
+    """
+
+    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
+        self.dht, self.p2p = dht, p2p
+        self.block_uids: List[ModuleUID] = list(block_uids)
+        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
+        self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
+        self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
+        self.last_update_time: DHTExpiration = -float("inf")
+        self.max_retries = max_retries
+        self._rpc_info = None
         self.lock_changes = threading.Lock()
         self.update_()
 
@@ -65,7 +65,7 @@ class RemoteSequenceManager:
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
         with self.lock_changes:
-            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix])
+            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
             subseq.block_infos = self.block_infos[ix]
             subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
             subseq.last_update_time = self.last_update_time
@@ -123,3 +123,25 @@ class RemoteSequenceManager:
 
     def __len__(self):
         return len(self.block_uids)
+
+    @property
+    def rpc_info(self):
+        """Return the rpc_info queried from one of the servers that hold the first block"""
+        if self._rpc_info is None:
+            retries = 0
+            for i in range(self.max_retries):
+                try:
+                    self.update_()
+                    peer_id = random.choice(list(self.block_infos[0].servers.keys()))
+                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
+                    outputs = RemoteExpertWorker.run_coroutine(
+                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
+                    )
+                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                except Exception as e:
+                    retries += 1
+                    if retries >= self.max_retries:
+                        raise e
+                    else:
+                        logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
+        return self._rpc_info