Jelajahi Sumber

inference session

justheuristic 3 tahun lalu
induk
melakukan
e7f716502c
1 mengubah file dengan 47 tambahan dan 67 penghapusan
  1. 47 67
      src/client/remote_sequential.py

+ 47 - 67
src/client/remote_sequential.py

@@ -1,14 +1,16 @@
 from __future__ import annotations
 
+import contextlib
 import logging
 import random
 
 import torch
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertInfo
 from torch import nn
 
-from src import DistributedBloomConfig
+from src import DistributedBloomConfig, RemoteTransformerBlock
 from src.client.remote_sequence_info import RemoteSequenceInfo
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
@@ -78,74 +80,52 @@ class RemoteSequential(nn.Sequential):
 class RemoteSequentialInferenceSession:
     """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
 
-    def __init__(self, remote_sequence_info: RemoteSequenceInfo):
+    def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
         self.remote_sequence_info = remote_sequence_info
+        self.p2p = p2p
         self.closed = False
+        self.stack = contextlib.ExitStack()
+        self.active_sessions = []
 
+    def __enter__(self):
+        assert not self.closed
+        self.stack.__enter__()
         # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
-        current_final_block = 0
-        self.active_chain = []
-
-        while current_final_block != len(remote_sequence_info):
-            candidate_spans = remote_sequence_info.spans_containing_block[current_final_block]
+        current_block = 0
+        while current_block != len(self.remote_sequence_info):
+            candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
             chosen_span = random.choice(candidate_spans)  # TODO this is a temporary code
-            assert chosen_span.start <= current_final_block < chosen_span.end
-
-            self.active_chain.append((current_final_block, chosen_span.end, chosen_span))
-            current_final_block = chosen_span.end
-
-
-
-#     def step(self, new_hidden_states: torch.Tensor):
-#         """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
-#         if self.closed:
-#             raise Exception("Session is closed, cannot perform step")
-#         # serialize inputs and put them into the queue
-#         inputs = (new_hidden_states,)
-#         outputs_serialized = RemoteExpertWorker.run_coroutine(
-#             self._step(
-#                 runtime_pb2.ExpertRequest(
-#                     uid=self.uid,
-#                     tensors=[
-#                         serialize_torch_tensor(tensor, proto.compression)
-#                         for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
-#                     ],
-#                 )
-#             )
-#         )
-#         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
-#         assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
-#         return outputs[0]
-# 
-#     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
-#         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
-#         await self._inputs_queue.put(inputs_serialized)
-#         return await anext(self._outputs_stream)
-# 
-#     def close(self):
-#         """Finish a given inference session, close the underlying connection"""
-#         if self._outputs_stream is None:
-#             return  # already closed
-#         RemoteExpertWorker.run_coroutine(self._aclose_stream())
-#         self._outputs_stream = self._inputs_queue = None
-#         self.closed = True
-# 
-#     async def _aclose_stream(self):
-#         """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
-#         if self._outputs_stream is None:
-#             return  # already closed
-#         await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
-#         try:
-#             await anext(self._outputs_stream)
-#         except StopAsyncIteration:
-#             pass
-# 
-#     def __del__(self):
-#         self.close()
-# 
-#     def __enter__(self):
-#         assert not self.closed
-#         return self
-# 
-#     def __exit__(self, *exc_details):
-#         self.close()
+            assert chosen_span.start <= current_block < chosen_span.end
+
+            # TODO begin throwaway prototype code
+            remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
+            remote.info
+            span_uids = self.remote_sequence_info.block_uids[current_block: chosen_span.end]
+            remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
+
+            self.active_sessions.append(remote.inference_session())
+            print('BEGIN', current_block, remote, self.active_sessions[-1])
+            self.stack.enter_context(self.active_sessions[-1])
+            current_block = chosen_span.end
+            # TODO end throwaway prototype code
+
+        return self
+
+    def step(self, inputs: torch.Tensor):
+        assert not self.closed
+        for session in self.active_sessions:
+            outputs = session.step(inputs)
+            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
+            inputs = outputs
+
+    def close(self, *exc_details):
+        """Finish a given inference session, close the underlying connection"""
+        assert not self.closed
+        self.active_sessions.clear()
+        self.closed = True
+
+    def __exit__(self, *exc_details):
+        self.close()
+
+    def __del__(self):
+        self.close()