@@ -28,10 +28,7 @@ class TransformerBackend(ModuleBackend):
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         attention_cache_handle = int(cache_metadata[0, 0].item())
         prefix_length = int(cache_metadata[0, 1].item())
-        (
-            hidden_states,
-            *_,
-        ) = inputs  # todo: this ignores any extra inputs for now; in future, it would be best to support attention mask as an extra input
+        hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"

         with self.memory_cache.use_cache(attention_cache_handle) as cache:
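For context, `cache_metadata` is a small integer tensor whose first row packs the two values `inference_step` unpacks above: the attention cache handle and the current prefix length. Below is a minimal sketch of how a caller might build it; the concrete values are hypothetical (a real handle would come from a prior `memory_cache` allocation), and only the `[1, 2]` layout is taken from the diff:

```python
import torch

# Hypothetical values: a real handle would come from an earlier
# memory_cache allocation, and prefix_length tracks how many tokens
# of this inference session have already been processed.
attention_cache_handle = 42
prefix_length = 7

# Shape [1, 2]: column 0 holds the cache handle, column 1 the prefix
# length, matching the cache_metadata[0, 0] / cache_metadata[0, 1]
# reads in inference_step. int32 matches the torch.IntTensor annotation.
cache_metadata = torch.tensor(
    [[attention_cache_handle, prefix_length]], dtype=torch.int32
)

assert int(cache_metadata[0, 0].item()) == attention_cache_handle
assert int(cache_metadata[0, 1].item()) == prefix_length
```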