Aleksandr Borzunov 2 years ago
parent
commit
bf0be9f031
1 changed files with 2 additions and 4 deletions
  1. 2 4
      src/server/handler.py

+ 2 - 4
src/server/handler.py

@@ -321,15 +321,13 @@ class TransformerConnectionHandler(ConnectionHandler):
                 num_heads = backend.module.self_attention.num_heads
                 head_dim = backend.module.self_attention.head_dim
 
-                descr = TensorDescriptor(
-                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
-                )
+                descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype)
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
 
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
                 total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
 
-            gib = 1024 ** 3
+            gib = 1024**3
             if backend is not None:
                 cur_size = backend.memory_cache.current_size_bytes
                 max_size = backend.memory_cache.max_size_bytes