@@ -202,7 +202,6 @@ class Server(threading.Thread):
if load_in_8bit:
dtype = block.input_layernorm.weight.dtype
- assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
block = replace_8bit_linear(block)
block = block.to(device)