Procházet zdrojové kódy

fix bug when cache assignment surpasses max length

justheuristic před 3 roky
rodič
revize
aa769ce846
1 změnil soubory, kde provedl 4 přidání a 5 odebrání
  1. 4 5
      src/server/backend.py

+ 4 - 5
src/server/backend.py

@@ -73,8 +73,8 @@ class TransformerBackend(ModuleBackend):
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
-                cache[:, :] = cache[:, hypo_ids]
-                layer_past = past_k, past_v = cache[0], cache[1]
+                cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
+                layer_past = past_k, past_v = cache[0, :prefix_length], cache[1, :prefix_length]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
                 hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
 
@@ -84,9 +84,8 @@ class TransformerBackend(ModuleBackend):
                 assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
                 assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
                 assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
-                if new_length < cache.shape[1]:
-                    cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
-                    cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
+                cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
+                cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
                 return (hidden_states,)
 
     def get_pools(self) -> Sequence[TaskPool]: