Browse Source

Force use_cache=True in config only (#497)

This reverts a part of #496 and instead overrides `use_cache` in `LlamaConfig`s only (so the correct value is visible by HF `.generate()` as well).
Alexander Borzunov 1 năm trước cách đây
mục cha
commit
b4d822afb2

+ 2 - 1
src/petals/models/bloom/model.py

@@ -43,7 +43,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         attention_mask: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,  # Not used here but needed for HF Transformers compatibility
+        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -63,6 +63,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
         assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
         assert not output_attentions, f"{output_attentions=} is not supported"
         assert not output_hidden_states, f"{output_hidden_states=} is not supported"
         assert return_dict is None or return_dict, f"{return_dict=} is not supported"

+ 1 - 0
src/petals/models/llama/config.py

@@ -43,4 +43,5 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfi
         result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
         config = result[0] if isinstance(result, tuple) else result
         config.pretraining_tp = 1  # This may give less accurate results but it doesn't matter if we use quantization
+        config.use_cache = True  # use_cache=False leads to identical results but is slower and not supported by Petals
         return result

+ 2 - 1
src/petals/models/llama/model.py

@@ -43,7 +43,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[RemotePastKeyValues] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,  # Not used here but needed for HF Transformers compatibility
+        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -65,6 +65,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
         assert not output_attentions, f"{output_attentions=} is not supported"
         assert not output_hidden_states, f"{output_hidden_states=} is not supported"
         assert return_dict is None or return_dict, f"{return_dict=} is not supported"