
Add part of deepprompts

Artem Chumachenko, 3 years ago
parent
current commit
88e6a75996
4 changed files with 17 additions and 9 deletions
  1. src/client/inference_session.py (+3 -3)
  2. src/client/remote_generation.py (+3 -2)
  3. src/server/backend.py (+10 -3)
  4. src/server/handler.py (+1 -1)

+ 3 - 3
src/client/inference_session.py

@@ -69,19 +69,19 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
-    def step(self, new_hidden_states: torch.Tensor):
+    def step(self, new_hidden_states: torch.Tensor, prompts: Optional[torch.Tensor] = None):
         """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states, torch.arange(len(new_hidden_states)))
+        inputs = (new_hidden_states, prompts, torch.arange(len(new_hidden_states)))
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
                     ],
                     metadata=self._serialized_metadata if not self.stepped else None,
                 )
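
Below is a minimal, self-contained sketch (not the repository's code) of the tensor order the client now packs for each inference step: hidden states, then the optional per-block prompts, then the hypothesis ids. Shapes are made up for illustration.

import torch

# Hypothetical shapes: batch of 2 sequences, 1 new token, hidden size 64,
# and a prompt prefix of length 4 for this block.
new_hidden_states = torch.randn(2, 1, 64)
prompts = torch.randn(2, 4, 64)                    # intermediate ("deep") prompts, may be None
hypo_ids = torch.arange(len(new_hidden_states))    # maps each hypothesis to its cache row

# The order must line up with the server-side inference_schema:
# (hidden_states, prompts, hypo_ids).
inputs = (new_hidden_states, prompts, hypo_ids)
print([tuple(t.shape) for t in inputs])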

+ 3 - 2
src/client/remote_generation.py

@@ -105,11 +105,12 @@ class RemoteGenerationMixin:
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
+                intermediate_prompts = None
                 if self.config.pre_seq_len > 0 and len(outputs) == 1:
-                    prompts, _ = self.transformer.get_prompt(embs.size(0))
+                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                     embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = sess.step(embs)[:, -1]
+                hidden_state = sess.step(embs, prompts=intermediate_prompts)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
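
A standalone sketch of the client-side loop change, using a hypothetical get_prompt stand-in (the real method lives on the model's transformer): on the first step, the input-level prefix is prepended to the embeddings, while the per-block prompts are handed to the session separately.

import torch

def get_prompt(batch_size, pre_seq_len=4, num_blocks=2, hid_size=64):
    # Hypothetical stand-in for transformer.get_prompt(): an input-level prefix
    # plus separate "deep" prompts for each remote block.
    prompts = torch.randn(batch_size, pre_seq_len, hid_size)
    intermediate_prompts = torch.randn(num_blocks, batch_size, pre_seq_len, hid_size)
    return prompts, intermediate_prompts

embs = torch.randn(2, 1, 64)                  # embeddings of the newest tokens
prompts, intermediate_prompts = get_prompt(embs.size(0))
embs = torch.cat([prompts, embs], dim=1)      # prepend the prefix on the first step only
print(embs.shape, intermediate_prompts.shape)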
 

+ 10 - 3
src/server/backend.py

@@ -55,13 +55,20 @@ class TransformerBackend(ModuleBackend):
             self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
         )
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
-        self.inference_schema = (self.args_schema, self.kwargs_schema, BatchTensorDescriptor((), dtype=torch.int64))
+        self.inference_schema = (
+            (
+                *self.args_schema,
+                BatchTensorDescriptor((), dtype=self.dtype),
+                BatchTensorDescriptor((), dtype=torch.int64),
+            ),
+            self.kwargs_schema,
+        )
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
-            hidden_states, hypo_ids = inputs  # todo: in future, it would be best to support attention mask here
+            hidden_states, hypo_ids, prompts = inputs  # todo: in future, it would be best to support attention mask here
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
@@ -72,7 +79,7 @@ class TransformerBackend(ModuleBackend):
                 layer_past = past_k, past_v = cache[0, hypo_ids, arange], cache[1, hypo_ids, arange]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
                 hidden_states, (new_k, new_v) = self.module.forward(
-                    hidden_states, layer_past=layer_past, use_cache=True
+                    hidden_states, layer_past=layer_past, use_cache=True, prompts=prompts
                 )
 
                 # todo remove these asserts once we pass all tests
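
A pure-torch sketch of the extended schema ordering the backend now advertises for each inference step; the descriptors are reduced to (name, dtype) pairs and the shapes are hypothetical.

import torch

backend_dtype = torch.float32
inference_schema = (
    ("hidden_states", backend_dtype),  # from args_schema
    ("prompts", backend_dtype),        # new: per-block deep prompts
    ("hypo_ids", torch.int64),         # hypothesis indices into the attention cache
)

# Example inputs matching the schema order.
inputs = (
    torch.randn(2, 1, 64),   # hidden_states
    torch.randn(2, 4, 64),   # prompts
    torch.arange(2),         # hypo_ids
)
for (name, dtype), tensor in zip(inference_schema, inputs):
    assert tensor.dtype == dtype, name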

+ 1 - 1
src/server/handler.py

@@ -65,7 +65,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
                     assert len(request.tensors) == 2, "Must specify hidden_states and input_ids" # TODO replace with schema
-                    hidden_states, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+                    hidden_states, intermediate_prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
 
                     if prefix_length + length_increment > max_length:
                         raise ValueError(
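
With the change above, a step request is expected to carry three tensors rather than two. A hypothetical helper, not the repository's handler, illustrating the unpacking performed after deserialization:

import torch

def unpack_step_request(tensors):
    # `tensors` stands in for the already-deserialized request payload.
    if len(tensors) != 3:
        raise ValueError("Expected hidden_states, intermediate_prompts and hypo_ids")
    hidden_states, intermediate_prompts, hypo_ids = tensors
    return hidden_states, intermediate_prompts, hypo_ids

hidden_states, intermediate_prompts, hypo_ids = unpack_step_request(
    (torch.randn(2, 1, 64), torch.randn(2, 4, 64), torch.arange(2))
)
print(hidden_states.shape, intermediate_prompts.shape, hypo_ids.shape)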