@@ -22,20 +22,20 @@ class RemotePastKeyValues(Cache):
 
     def __init__(self) -> None:
         super().__init__()
-        self.seen_tokens = 0
+        self._seen_tokens = 0
         self.hypo_ids: Optional[torch.LongTensor] = None
 
     def __getitem__(self, _index: int) -> List[torch.Tensor]:
         return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        return self.seen_tokens
+        return self._seen_tokens
 
     def get_max_length(self) -> Optional[int]:
         return None
 
     def update_seen(self, new_seen: int) -> None:
-        self.seen_tokens += new_seen
+        self._seen_tokens += new_seen
 
     def reorder_cache(self, beam_idx):
         raise NotImplementedError("Beam search reordering is not implemented yet")
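
For reference, below is a minimal, self-contained sketch of how the renamed counter behaves. The class body mirrors the post-change code in the hunk above; the import path (transformers.cache_utils.Cache), the DUMMY stand-in tensor, and the driver lines at the bottom are assumptions added for illustration and are not part of the diff.

from typing import List, Optional

import torch
from transformers.cache_utils import Cache  # assumes a transformers version that exposes Cache here

DUMMY = torch.empty(0)  # stand-in for petals' DUMMY placeholder tensor


class RemotePastKeyValues(Cache):
    """Mirror of the post-change class, trimmed to what the diff touches."""

    def __init__(self) -> None:
        super().__init__()
        self._seen_tokens = 0  # underscore-prefixed, as in newer transformers Cache implementations
        self.hypo_ids: Optional[torch.LongTensor] = None

    def __getitem__(self, _index: int) -> List[torch.Tensor]:
        return [DUMMY]  # placeholder entry for prepare_inputs_for_generation()

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        return self._seen_tokens

    def get_max_length(self) -> Optional[int]:
        return None

    def update_seen(self, new_seen: int) -> None:
        self._seen_tokens += new_seen


# Hypothetical driver: the client advances the counter after each remote inference step.
cache = RemotePastKeyValues()
assert cache.get_seq_length() == 0
cache.update_seen(16)  # e.g. a 16-token prompt processed by the remote servers
cache.update_seen(1)   # one more generated token
assert cache.get_seq_length() == 17

The sketch is only meant to show that get_seq_length() reads the same private counter that update_seen() advances; the real class lives alongside the rest of the client code and uses the project's own DUMMY tensor.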