|
@@ -5,7 +5,7 @@ import logging
|
|
|
import random
|
|
|
|
|
|
import torch
|
|
|
-from hivemind import DHT, get_logger, use_hivemind_log_handler
|
|
|
+from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
|
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
from hivemind.moe.expert_uid import ExpertInfo
|
|
|
from torch import nn
|
|
@@ -99,12 +99,9 @@ class RemoteSequentialInferenceSession:
|
|
|
|
|
|
# TODO begin throwaway prototype code
|
|
|
remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
|
|
|
- remote.info
|
|
|
span_uids = self.remote_sequence_info.block_uids[current_block: chosen_span.end]
|
|
|
remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
|
|
|
-
|
|
|
self.active_sessions.append(remote.inference_session())
|
|
|
- print('BEGIN', current_block, remote, self.active_sessions[-1])
|
|
|
self.stack.enter_context(self.active_sessions[-1])
|
|
|
current_block = chosen_span.end
|
|
|
# TODO end throwaway prototype code
|
|
@@ -117,13 +114,14 @@ class RemoteSequentialInferenceSession:
|
|
|
outputs = session.step(inputs)
|
|
|
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
|
|
|
inputs = outputs
|
|
|
+ return inputs
|
|
|
|
|
|
def close(self, *exc_details):
|
|
|
"""Finish a given inference session, close the underlying connection"""
|
|
|
- assert not self.closed
|
|
|
- self.stack.__exit__(*exc_details)
|
|
|
- self.active_sessions.clear()
|
|
|
- self.closed = True
|
|
|
+ if not self.closed:
|
|
|
+ self.stack.__exit__(*exc_details or (None, None, None))
|
|
|
+ self.active_sessions.clear()
|
|
|
+ self.closed = True
|
|
|
|
|
|
def __exit__(self, *exc_details):
|
|
|
self.close(*exc_details)
|