@@ -195,6 +195,8 @@ class Server(threading.Thread):
)
if load_in_8bit:
+ dtype = block.input_layernorm.weight.dtype
+ assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
block = replace_8bit_linear(block)
block = block.to(device)