Explorar o código

add assert on dtype for load_in_8bit

dbaranchuk %!s(int64=3) %!d(string=hai) anos
pai
achega
b8a78b8254
Modificáronse 1 ficheiros con 2 adicións e 0 borrados
  1. 2 0
      src/server/server.py

+ 2 - 0
src/server/server.py

@@ -195,6 +195,8 @@ class Server(threading.Thread):
             )
 
             if load_in_8bit:
+                dtype = block.input_layernorm.weight.dtype
+                assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
                 block = replace_8bit_linear(block)
 
             block = block.to(device)