Преглед на файлове

add assert on dtype for load_in_8bit

dbaranchuk преди 3 години
родител
ревизия
b8a78b8254
променени са 1 файла, в които са добавени 2 реда и са изтрити 0 реда
  1. 2 0
      src/server/server.py

+ 2 - 0
src/server/server.py

@@ -195,6 +195,8 @@ class Server(threading.Thread):
             )
 
             if load_in_8bit:
+                dtype = block.input_layernorm.weight.dtype
+                assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
                 block = replace_8bit_linear(block)
 
             block = block.to(device)