justheuristic 3 vuotta sitten
vanhempi
commit
dfc7449cd5
1 muutettua tiedostoa jossa 6 lisäystä ja 8 poistoa
  1. 6 8
      src/client/remote_sequential.py

+ 6 - 8
src/client/remote_sequential.py

@@ -5,7 +5,7 @@ import logging
 import random
 
 import torch
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from torch import nn
@@ -99,12 +99,9 @@ class RemoteSequentialInferenceSession:
 
             # TODO begin throwaway prototype code
             remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
-            remote.info
             span_uids = self.remote_sequence_info.block_uids[current_block: chosen_span.end]
             remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
-
             self.active_sessions.append(remote.inference_session())
-            print('BEGIN', current_block, remote, self.active_sessions[-1])
             self.stack.enter_context(self.active_sessions[-1])
             current_block = chosen_span.end
             # TODO end throwaway prototype code
@@ -117,13 +114,14 @@ class RemoteSequentialInferenceSession:
             outputs = session.step(inputs)
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
+        return inputs
 
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
-        assert not self.closed
-        self.stack.__exit__(*exc_details)
-        self.active_sessions.clear()
-        self.closed = True
+        if not self.closed:
+            self.stack.__exit__(*exc_details or (None, None, None))
+            self.active_sessions.clear()
+            self.closed = True
 
     def __exit__(self, *exc_details):
         self.close(*exc_details)