@@ -31,7 +31,6 @@ class TransformerBackend(ModuleBackend):
             attention_cache_handle = int(cache_metadata[0, 0].item())
             current_sequence_length = int(cache_metadata[0, 1].item())
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 print('METADATA:', cache_metadata, "CACHE ENTRIES:", len(self.memory_cache._allocated_tensors))
-                print(inputs[0].shape, cache.shape)
                 cache[...] += 1
                 return (inputs[0] + cache.flatten()[0],)
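
For reviewers unfamiliar with the cache plumbing this hunk touches, here is a minimal sketch of the convention it relies on. `ToyMemoryCache` below is a hypothetical stand-in, not the repo's actual `MemoryCache` class; the only parts taken from the diff itself are the `cache_metadata` layout (column 0 = attention cache handle, column 1 = current sequence length) and the `use_cache` context-manager pattern.

```python
from contextlib import contextmanager

import torch


class ToyMemoryCache:
    """Hypothetical stand-in for the server's MemoryCache (an assumption, not the real API)."""

    def __init__(self):
        self._allocated_tensors = {}  # handle -> cached tensor, as printed in the hunk

    def allocate(self, handle: int, shape: tuple) -> None:
        self._allocated_tensors[handle] = torch.zeros(shape)

    @contextmanager
    def use_cache(self, handle: int):
        # Yield the tensor registered under `handle`, mirroring the diff's usage.
        yield self._allocated_tensors[handle]


# cache_metadata layout taken from the hunk:
# column 0 = attention cache handle, column 1 = current sequence length
memory_cache = ToyMemoryCache()
memory_cache.allocate(handle=7, shape=(1, 4))
cache_metadata = torch.tensor([[7, 0]], dtype=torch.int64)

attention_cache_handle = int(cache_metadata[0, 0].item())
current_sequence_length = int(cache_metadata[0, 1].item())
with memory_cache.use_cache(attention_cache_handle) as cache:
    cache[...] += 1  # the same in-place bump the debug code in the hunk performs
```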