浏览代码

add assert on dtype for load_in_8bit

dbaranchuk 3 年之前
父节点
当前提交
b8a78b8254
共有 1 个文件被更改,包括 2 次插入0 次删除
  1. 2 0
      src/server/server.py

+ 2 - 0
src/server/server.py

@@ -195,6 +195,8 @@ class Server(threading.Thread):
             )
 
             if load_in_8bit:
+                dtype = block.input_layernorm.weight.dtype
+                assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
                 block = replace_8bit_linear(block)
 
             block = block.to(device)