Aleksandr Borzunov 3 éve
szülő
commit
33f6bdfac2
2 módosított fájl, 9 hozzáadás és 4 törlés
  1. 3 4
      cli/run_server.py
  2. 6 0
      tests/test_block_exact_match.py

+ 3 - 4
cli/run_server.py

@@ -98,11 +98,10 @@ def main():
     compression_type = args.pop("compression")
     compression = getattr(CompressionType, compression_type)
 
-    attn_cache_size = args.pop("attention_cache_bytes")
+    attn_cache_size = args.pop("attn_cache_size")
     if attn_cache_size is not None:
-        attention_cache_bytes = parse_size_as_bytes(attn_cache_size)
-    assert isinstance(
-        attn_cache_size, (int, type(None))
+        attn_cache_size = parse_size_as_bytes(attn_cache_size)
+    assert isinstance(attn_cache_size, (int, type(None))
     ), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
 
     use_auth_token = args.pop("use_auth_token")

+ 6 - 0
tests/test_block_exact_match.py

@@ -37,3 +37,9 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
 
         assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
         assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
+
+        # test that max length is respected
+        with remote_block.inference_session(max_length=inputs.shape[1] - 1) as sess:
+            for i in range(inputs.shape[1]):
+                sess.step(inputs[:, i : i + 1, :])
+