Bladeren bron

move block choice from InferenceSession to RemoteSequenceManager

justheuristic 3 jaren geleden
bovenliggende
commit
a652260793
4 gewijzigde bestanden met toevoegingen van 31 en 12 verwijderingen
  1. 4 6
      src/client/remote_block.py
  2. 2 5
      src/client/remote_sequential.py
  3. 21 0
      src/client/sequence_manager.py
  4. 4 1
      src/data_structures.py

+ 4 - 6
src/client/remote_block.py

@@ -39,9 +39,9 @@ class RemoteTransformerBlock(RemoteExpert):
 
     def inference_session(self) -> RemoteTransformerBlockInferenceSession:
         """Initialize a new inference session with the specified remote server"""
-        return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(
-            self.stub, self.uid, self.info
-        ))
+        return RemoteExpertWorker.run_coroutine(
+            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info)
+        )
 
     def begin_inference_session(self):
         logger.warning("beging_inference_session was renamed to just inference_session")
@@ -66,9 +66,7 @@ class RemoteTransformerBlockInferenceSession:
     ) -> RemoteTransformerBlockInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
-        outputs_stream = await stub.rpc_inference(
-            cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout
-        )
+        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
         return cls(uid, rpc_info, inputs_queue, outputs_stream)
 
     @staticmethod

+ 2 - 5
src/client/remote_sequential.py

@@ -120,11 +120,8 @@ class RemoteSequentialInferenceSession:
         assert not self.closed
         self.stack.__enter__()
         # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
-        current_block = 0
-        while current_block != len(self.sequence_manager):
-            candidate_spans = self.sequence_manager.spans_containing_block[current_block]
-            chosen_span = random.choice(candidate_spans)  # TODO this is a temporary code
-            assert chosen_span.start <= current_block < chosen_span.end
+
+        for chosen_span in self.sequence_manager.make_sequence():
 
             # TODO begin throwaway prototype code
             remote = RemoteTransformerBlock(self.sequence_manager.block_infos[current_block], self.p2p)

+ 21 - 0
src/client/sequence_manager.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import random
 import threading
 from typing import List, Optional, Sequence, Tuple, Union
 
@@ -38,6 +39,26 @@ class RemoteSequenceManager:
             assert info is not None, f"Found no remote peers for block {uid}"
         assert self.spans_by_priority and self.spans_containing_block
 
+    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]:
+        """
+        Form a sequence of remote servers that collectively serve all consecutive layers
+
+        :param start_index: optional index of the first module in a sequence, default = the first of block_uids
+        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
+        """
+        end_index = end_index if end_index is not None else len(self.block_uids)
+        span_sequence = []
+        current_index = start_index
+        while current_index != end_index - 1:
+            candidate_spans = self.spans_containing_block[current_index]
+
+            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
+
+            assert chosen_span.start <= current_index < chosen_span.end
+            span_sequence.append(chosen_span)
+
+        return span_sequence
+
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
         """Get a RemoteSequenceManager for a sub-sequence of blocks"""
         assert isinstance(ix, (int, slice))

+ 4 - 1
src/data_structures.py

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict
+from typing import Any, Dict
 
 from hivemind import PeerID
 
@@ -36,3 +36,6 @@ class RemoteSpanInfo:
     start: int
     end: int
     peer_id: PeerID
+
+
+RPCInfo = Dict[str, Any]