|
@@ -195,6 +195,8 @@ class Server(threading.Thread):
|
|
)
|
|
)
|
|
|
|
|
|
if load_in_8bit:
|
|
if load_in_8bit:
|
|
|
|
+ dtype = block.input_layernorm.weight.dtype
|
|
|
|
+ assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
|
|
block = replace_8bit_linear(block)
|
|
block = replace_8bit_linear(block)
|
|
|
|
|
|
block = block.to(device)
|
|
block = block.to(device)
|