瀏覽代碼

Make inference session fields private

Aleksandr Borzunov 2 年之前
父節點
當前提交
b6316a5603
共有 1 個文件被更改,包括 42 次插入 和 42 次删除
  1. 42 42
      src/client/inference_session.py

+ 42 - 42
src/client/inference_session.py

@@ -167,24 +167,24 @@ class RemoteSequentialInferenceSession:
     """
 
     def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, **metadata):
-        self.sequence_manager = sequence_manager
-        self.p2p = p2p
-        self.closed = False
-        self.chosen_spans = []
-        self.server_sessions = []
-        self.server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
-        self.position = 0
-        self.metadata = metadata
+        self._sequence_manager = sequence_manager
+        self._p2p = p2p
+        self._closed = False
+        self._chosen_spans = []
+        self._server_sessions = []
+        self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
+        self._position = 0
+        self._metadata = metadata
 
     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[RemoteServerInferenceSession]:
         server_sessions = []
         for span in chosen_spans:
-            stub = TransformerConnectionHandler.get_stub(self.p2p, span.peer_id)
-            span_uids = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[span.start : span.end])
+            stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
+            span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
             session = RemoteExpertWorker.run_coroutine(
                 RemoteServerInferenceSession.create(
-                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.sequence_manager.timeout,
-                    **self.metadata
+                    stub, span_uids, rpc_info=self._sequence_manager.rpc_info, timeout=self._sequence_manager.timeout,
+                    **self._metadata
                 )
             )
             server_sessions.append(session)
@@ -200,54 +200,54 @@ class RemoteSequentialInferenceSession:
                 logger.log(exc_loglevel, "Caught exception while closing connection to server:", exc_info=True)
 
     def __enter__(self):
-        assert not self.closed and not self.chosen_spans
+        assert not self._closed and not self._chosen_spans
         return self
 
     def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
-        assert not self.closed
+        assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
+            assert prompts.ndim == 4 and prompts.shape[0] == len(self._sequence_manager)
         n_input_tokens = inputs.shape[1]
 
         server_idx = 0
         block_idx = 0
         recovery_mode = False
-        while block_idx < len(self.sequence_manager):
+        while block_idx < len(self._sequence_manager):
             for attempt_no in itertools.count():
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                 try:
-                    if not self.chosen_spans or not self.server_sessions or attempt_no >= 1:
+                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
                         if attempt_no >= 1:
-                            self._exit_server_sessions(self.server_sessions[server_idx:], verbose=False)
-                            self.server_sessions[server_idx:] = []
-                            self.chosen_spans[server_idx:] = []
-                            self.server_inputs[server_idx + 1:] = []
+                            self._exit_server_sessions(self._server_sessions[server_idx:], verbose=False)
+                            self._server_sessions[server_idx:] = []
+                            self._chosen_spans[server_idx:] = []
+                            self._server_inputs[server_idx + 1:] = []
 
-                            self.sequence_manager.update_()
+                            self._sequence_manager.update_()
                             recovery_mode = True
                             if attempt_no == 1:
                                 logger.info("Entering recovery mode, remote attention caches will be regenerated")
 
-                        backup_spans = self.sequence_manager.make_sequence(block_idx)
-                        self.chosen_spans.extend(backup_spans)
-                        self.server_sessions.extend(self._enter_server_sessions(backup_spans))
+                        backup_spans = self._sequence_manager.make_sequence(block_idx)
+                        self._chosen_spans.extend(backup_spans)
+                        self._server_sessions.extend(self._enter_server_sessions(backup_spans))
                         logger.debug(f"Found path from block {block_idx} via {len(backup_spans)} servers")
 
-                    session = self.server_sessions[server_idx]
-                    span = self.chosen_spans[server_idx]
+                    session = self._server_sessions[server_idx]
+                    span = self._chosen_spans[server_idx]
 
-                    if server_idx == len(self.server_inputs):
-                        self.server_inputs.append(inputs)
-                    elif self.server_inputs[server_idx].shape[1] == self.position:
-                        self.server_inputs[server_idx] = torch.cat([self.server_inputs[server_idx], inputs], dim=1)
-                    assert self.server_inputs[server_idx].shape[1] == self.position + n_input_tokens
+                    if server_idx == len(self._server_inputs):
+                        self._server_inputs.append(inputs)
+                    elif self._server_inputs[server_idx].shape[1] == self._position:
+                        self._server_inputs[server_idx] = torch.cat([self._server_inputs[server_idx], inputs], dim=1)
+                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens
 
                     if recovery_mode:
-                        inputs = self.server_inputs[server_idx]  # Take full inputs including prefix
+                        inputs = self._server_inputs[server_idx]  # Take full inputs including prefix
                     outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
                     assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
                     inputs = outputs
@@ -256,25 +256,25 @@ class RemoteSequentialInferenceSession:
                     block_idx = span.end
                     break
                 except Exception as e:
-                    delay = self.sequence_manager.min_backoff * 2**attempt_no
+                    delay = self._sequence_manager.min_backoff * 2**attempt_no
                     logger.warning(
                         f"Caught exception when running inference from block {block_idx} "
-                        f"(retry in {delay:.1f} sec): {repr(e)}"
+                        f"(retry in {delay:.0f} sec): {repr(e)}"
                     )
                     logger.debug("See detailed traceback below:", exc_info=True)
                     time.sleep(delay)
 
-        self.position += n_input_tokens
+        self._position += n_input_tokens
         return inputs
 
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
-        if not self.closed:
-            self.server_inputs.clear()
-            self._exit_server_sessions(self.server_sessions, verbose=True)
-            self.server_sessions.clear()
-            self.chosen_spans.clear()
-            self.closed = True
+        if not self._closed:
+            self._server_inputs.clear()
+            self._exit_server_sessions(self._server_sessions, verbose=True)
+            self._server_sessions.clear()
+            self._chosen_spans.clear()
+            self._closed = True
 
     def __exit__(self, *exc_details):
         self.close(*exc_details)