Bläddra i källkod

add assert on dtype for load_in_8bit

dbaranchuk 3 år sedan
förälder
incheckning
b8a78b8254
1 ändrade filer med 2 tillägg och 0 borttagningar
  1. 2 0
      src/server/server.py

+ 2 - 0
src/server/server.py

@@ -195,6 +195,8 @@ class Server(threading.Thread):
             )
 
             if load_in_8bit:
+                dtype = block.input_layernorm.weight.dtype
+                assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
                 block = replace_8bit_linear(block)
 
             block = block.to(device)