|
@@ -109,8 +109,8 @@ class RemoteSequential(nn.Module):
|
|
class RemoteSequentialInferenceSession:
|
|
class RemoteSequentialInferenceSession:
|
|
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
|
|
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
|
|
|
|
|
|
- def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P):
|
|
|
|
- self.remote_sequence_info = remote_sequence_info
|
|
|
|
|
|
+ def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P):
|
|
|
|
+ self.sequence_manager = sequence_manager
|
|
self.p2p = p2p
|
|
self.p2p = p2p
|
|
self.closed = False
|
|
self.closed = False
|
|
self.stack = contextlib.ExitStack()
|
|
self.stack = contextlib.ExitStack()
|
|
@@ -121,15 +121,15 @@ class RemoteSequentialInferenceSession:
|
|
self.stack.__enter__()
|
|
self.stack.__enter__()
|
|
# TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
|
|
# TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
|
|
current_block = 0
|
|
current_block = 0
|
|
- while current_block != len(self.remote_sequence_info):
|
|
|
|
- candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
|
|
|
|
|
|
+ while current_block != len(self.sequence_manager):
|
|
|
|
+ candidate_spans = self.sequence_manager.spans_containing_block[current_block]
|
|
chosen_span = random.choice(candidate_spans) # TODO this is a temporary code
|
|
chosen_span = random.choice(candidate_spans) # TODO this is a temporary code
|
|
assert chosen_span.start <= current_block < chosen_span.end
|
|
assert chosen_span.start <= current_block < chosen_span.end
|
|
|
|
|
|
# TODO begin throwaway prototype code
|
|
# TODO begin throwaway prototype code
|
|
- remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
|
|
|
|
|
|
+ remote = RemoteTransformerBlock(self.sequence_manager.block_infos[current_block], self.p2p)
|
|
_ = remote.info # TODO fix
|
|
_ = remote.info # TODO fix
|
|
- span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
|
|
|
|
|
|
+ span_uids = self.sequence_manager.block_uids[current_block: chosen_span.end]
|
|
remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
|
|
remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
|
|
self.active_sessions.append(remote.inference_session())
|
|
self.active_sessions.append(remote.inference_session())
|
|
self.stack.enter_context(self.active_sessions[-1])
|
|
self.stack.enter_context(self.active_sessions[-1])
|