justheuristic 3 лет назад
Родитель
Commit
15d0ea7129
1 измененный файл с 1 добавлением и 4 удалениями
  1. 1 4
      src/server/backend.py

+ 1 - 4
src/server/backend.py

@@ -28,10 +28,7 @@ class TransformerBackend(ModuleBackend):
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         attention_cache_handle = int(cache_metadata[0, 0].item())
         prefix_length = int(cache_metadata[0, 1].item())
-        (
-            hidden_states,
-            *_,
-        ) = inputs  # todo: this ignores any extra inputs for now; in future, it would be best to support attention mask as an extra input
+        hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
         with self.memory_cache.use_cache(attention_cache_handle) as cache: