add assert on dtype for load_in_8bit

dbaranchuk 3 years ago
Parent
Commit
b8a78b8254
1 changed file with 2 additions and 0 deletions
  1. src/server/server.py (+2 -0)

+ 2 - 0
src/server/server.py

@@ -195,6 +195,8 @@ class Server(threading.Thread):
             )
 
             if load_in_8bit:
+                dtype = block.input_layernorm.weight.dtype
+                assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
                 block = replace_8bit_linear(block)
 
             block = block.to(device)
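
The guard added in this commit can also be written as an explicit exception. The following is a minimal sketch, not part of the commit: `check_8bit_dtype` and `make_fake_block` are hypothetical names, and the placeholder dtype strings stand in for `torch.float16` so the example runs without PyTorch installed. One reason to consider this form is that `assert` statements are stripped when Python runs with `-O`, whereas a raised `ValueError` is always evaluated.

```python
from types import SimpleNamespace


def check_8bit_dtype(block, expected_dtype):
    """Raise if the block's layernorm weights are not in the expected dtype.

    Mirrors the assert in the diff, but raises ValueError so the check
    still runs under `python -O`, which removes assert statements.
    """
    dtype = block.input_layernorm.weight.dtype
    if dtype != expected_dtype:
        raise ValueError(f"'load_in_8bit' does not support {dtype} for now")
    return dtype


def make_fake_block(dtype):
    """Hypothetical stand-in for a transformer block (assumption for the demo).

    Real code would pass the loaded block and compare against torch.float16.
    """
    weight = SimpleNamespace(dtype=dtype)
    return SimpleNamespace(input_layernorm=SimpleNamespace(weight=weight))


if __name__ == "__main__":
    # Passes: the placeholder dtype matches the expected one.
    check_8bit_dtype(make_fake_block("float16"), "float16")

    # Fails: any other dtype is rejected before 8-bit conversion.
    try:
        check_8bit_dtype(make_fake_block("bfloat16"), "float16")
    except ValueError as err:
        print(err)
```

In the server code, the call would sit where the assert is now, directly before `replace_8bit_linear(block)`, with `torch.float16` as the expected dtype.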