@@ -78,6 +78,7 @@ class _ServerInferenceSession:
     def step(
         self,
         new_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -104,7 +105,7 @@ class _ServerInferenceSession:
         assert hypo_ids.dtype == torch.int64

         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states, prompts, hypo_ids)
+        inputs = (new_hidden_states, attention_mask, prompts, hypo_ids)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
@@ -212,7 +213,9 @@ class InferenceSession:
         assert not self._closed and not self._chosen_spans
         return self

-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(
+        self, inputs: torch.Tensor, attention_mask: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs
+    ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -226,6 +229,7 @@ class InferenceSession:
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
+        attention_mask = attention_mask.cpu()
         prompts = prompts.cpu()

         n_input_tokens = inputs.shape[1]
@@ -294,7 +298,7 @@ class InferenceSession:
             else:
                 inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further

-            outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
+            outputs = session.step(inputs, attention_mask, prompts[span.start : span.end], **kwargs)
             assert (
                 inputs.shape == outputs.shape
             ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
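Usage note (not part of the patch): after this change, InferenceSession.step takes the attention mask as a positional argument right after the hidden states and before `prompts`. Below is a minimal sketch of the updated call. `sess` is a hypothetical placeholder for an already-opened InferenceSession, and the tensor shapes and the int64 mask dtype are illustrative assumptions, not taken from this diff; only the argument order comes from the patch.

    import torch

    # Hidden states for the tokens of the current step:
    # (batch, n_input_tokens, hidden_size) -- sizes here are made up.
    hidden_states = torch.randn(1, 5, 4096)

    # Attention mask over the input tokens: 1 = attend, 0 = padding.
    # (int64 mirrors the usual HF convention; the diff does not pin the dtype.)
    attention_mask = torch.ones(1, 5, dtype=torch.int64)

    # New signature: the mask goes between the hidden states and `prompts`.
    # `sess` is hypothetical -- an InferenceSession opened elsewhere by the client.
    outputs = sess.step(hidden_states, attention_mask)

    # Per the assert in the patch, step() preserves the input shape.
    assert outputs.shape == hidden_states.shape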