Browse Source

sequence_info -> sequence_manager

justheuristic 3 years ago
parent
commit
1d72f1890e
1 changed files with 6 additions and 6 deletions
  1. 6 6
      src/client/remote_sequential.py

+ 6 - 6
src/client/remote_sequential.py

@@ -109,8 +109,8 @@ class RemoteSequential(nn.Module):
 class RemoteSequentialInferenceSession:
     """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
 
-    def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P):
-        self.remote_sequence_info = remote_sequence_info
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P):
+        self.sequence_manager = sequence_manager
         self.p2p = p2p
         self.closed = False
         self.stack = contextlib.ExitStack()
@@ -121,15 +121,15 @@ class RemoteSequentialInferenceSession:
         self.stack.__enter__()
         # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
         current_block = 0
-        while current_block != len(self.remote_sequence_info):
-            candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
+        while current_block != len(self.sequence_manager):
+            candidate_spans = self.sequence_manager.spans_containing_block[current_block]
             chosen_span = random.choice(candidate_spans)  # TODO this is a temporary code
             assert chosen_span.start <= current_block < chosen_span.end
 
             # TODO begin throwaway prototype code
-            remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
+            remote = RemoteTransformerBlock(self.sequence_manager.block_infos[current_block], self.p2p)
             _ = remote.info  # TODO fix
-            span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
+            span_uids = self.sequence_manager.block_uids[current_block: chosen_span.end]
             remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
             self.active_sessions.append(remote.inference_session())
             self.stack.enter_context(self.active_sessions[-1])