|
@@ -202,7 +202,6 @@ class Server(threading.Thread):
|
|
|
|
|
|
if load_in_8bit:
|
|
if load_in_8bit:
|
|
dtype = block.input_layernorm.weight.dtype
|
|
dtype = block.input_layernorm.weight.dtype
|
|
- assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
|
|
|
|
block = replace_8bit_linear(block)
|
|
block = replace_8bit_linear(block)
|
|
|
|
|
|
block = block.to(device)
|
|
block = block.to(device)
|