Explorar o código

check for past key values properly

justheuristic %!s(int64=3) %!d(string=hai) anos
pai
achega
eea6fbb318
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      src/bloom/model.py

+ 1 - 1
src/bloom/model.py

@@ -235,7 +235,7 @@ class BloomModel(BloomPreTrainedModel):
 
         # Compute alibi tensor: check build_alibi_tensor documentation
         current_sequence_length = hidden_states.shape[1]
-        if past_key_values is not None and past_key_values[0] is not None:
+        if past_key_values and past_key_values[0]:
             current_sequence_length += past_key_values[0][0].shape[1]
         alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)