
client-side prompts

justheuristic · 3 years ago · commit 33aa952d41
1 changed file with 29 additions and 5 deletions

src/client/inference_session.py · +29 −5

@@ -22,6 +22,7 @@ from hivemind.proto import runtime_pb2
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
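
The DUMMY / is_dummy pair imported above lets callers ship a tiny placeholder tensor instead of None over RPC. A minimal sketch of these helpers, assuming the conventional "empty tensor means absent" encoding (the real src/utils/misc.py may differ):

    import torch

    DUMMY = torch.empty(0)  # zero-element placeholder, cheap to serialize

    def is_dummy(tensor: torch.Tensor) -> bool:
        # an empty tensor stands for "no prompts were provided"
        return tensor.numel() == 0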
@@ -44,6 +45,7 @@ class RemoteTransformerBlockInferenceSession:
         max_length: int,
     ):
         self.uid, self.rpc_info = uid, rpc_info
+        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
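
For context on num_blocks: a chained UID names every block this session talks to, joined by CHAIN_DELIMITER, so counting delimiters and adding one yields the chain length. A hypothetical example, assuming CHAIN_DELIMITER is a single space (see src/data_structures.py for the real value):

    CHAIN_DELIMITER = " "  # assumed value
    uid = "model.block0 model.block1 model.block2"  # made-up block UIDs
    num_blocks = uid.count(CHAIN_DELIMITER) + 1  # == 3, one prompt slice per block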
@@ -69,12 +71,30 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
-    def step(self, new_hidden_states: torch.Tensor, prompts: Optional[torch.Tensor] = None):
-        """Inference step: send a chunk of input tesors and receive a chunk of outputs"""
+    def step(self, new_hidden_states: torch.Tensor,
+             prompts: Optional[torch.Tensor] = None,
+             hypo_ids: Optional[torch.Tensor] = None):
+        """
+        Inference step: send a chunk of input tensors and receive a chunk of outputs
+        :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs;
+          if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
+        """
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]"
+            assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
+            assert prompts.shape[2] <= new_hidden_states.shape[1]
+            assert prompts.shape[3] == new_hidden_states.shape[2]
+
+        assert hypo_ids is None, "TODO implement hypo_ids here"
+        hypo_ids = torch.arange(len(new_hidden_states))
+
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states, prompts, torch.arange(len(new_hidden_states)))
+        inputs = (new_hidden_states, prompts, hypo_ids)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
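
Putting the new signature together: deep prompts carry one [batch_size, prefix_len, hid_size] slab per block in the chain, and the asserts above enforce exactly that. A usage sketch with made-up sizes, where session stands for a RemoteTransformerBlockInferenceSession spanning 4 chained blocks:

    import torch

    batch_size, prefix_len, hid_size = 2, 16, 4096  # hypothetical dimensions
    hidden = torch.randn(batch_size, prefix_len, hid_size)
    # one prompt slab per block: [num_blocks, batch_size, prefix_len, hid_size]
    prompts = torch.zeros(4, batch_size, prefix_len, hid_size)
    outputs = session.step(hidden, prompts=prompts)  # hypo_ids must stay None for now
    assert outputs.shape == hidden.shape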
@@ -161,12 +181,16 @@ class RemoteSequentialInferenceSession:
 
         return self
 
-    def step(self, inputs: torch.Tensor):
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
         assert not self.closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
-        for session in self.inference_sessions:
-            outputs = session.step(inputs)
+        for span, session in zip(self.chosen_spans, self.inference_sessions):
+            outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
         return inputs
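
At the sequence level, the caller passes a single prompts tensor covering every layer of the model, and each per-server session receives only the slice for its own span. A usage sketch under the same assumptions (sess stands for a RemoteSequentialInferenceSession; sizes are made up):

    import torch

    num_layers, batch_size, prefix_len, hid_size = 24, 1, 4, 4096  # hypothetical
    deep_prompts = torch.randn(num_layers, batch_size, prefix_len, hid_size)
    hidden = torch.randn(batch_size, prefix_len, hid_size)

    with torch.inference_mode():  # avoids the grad-enabled warning above
        hidden = sess.step(hidden, prompts=deep_prompts)  # sliced per server span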