Răsfoiți Sursa

forward metadata in sequential inference session

justheuristic 3 ani în urmă
părinte
comite
d6053a2516
2 a modificat fișierele cu 2 adăugiri și 5 ștergeri
  1. 1 4
      src/client/inference_session.py
  2. 1 1
      src/client/remote_generation.py

+ 1 - 4
src/client/inference_session.py

@@ -153,10 +153,7 @@ class RemoteSequentialInferenceSession:
             span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
             span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
             inference_session = RemoteExpertWorker.run_coroutine(
             inference_session = RemoteExpertWorker.run_coroutine(
                 RemoteTransformerBlockInferenceSession._create(
                 RemoteTransformerBlockInferenceSession._create(
-                    stub,
-                    span_uids,
-                    rpc_info=self.sequence_manager.rpc_info,
-                    timeout=self.timeout,
+                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
                 )
                 )
             )
             )
             self.inference_sessions.append(inference_session)
             self.inference_sessions.append(inference_session)

+ 1 - 1
src/client/remote_generation.py

@@ -61,7 +61,7 @@ class RemoteGenerationMixin:
             model_kwargs.get("stopping_criteria", None) is None
             model_kwargs.get("stopping_criteria", None) is None
         ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
         ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
         if inputs is not None:
         if inputs is not None:
-            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, "inputs must be a 3d tensor [batch, len, hid]"
+            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
         prefix_length = 0 if inputs is None else inputs.size(1)
         prefix_length = 0 if inputs is None else inputs.size(1)
 
 
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id