|
@@ -60,7 +60,7 @@ def main():
|
|
|
revision=args.revision,
|
|
|
torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
|
|
|
load_in_8bit=args.torch_dtype == "int8",
|
|
|
- device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"},
|
|
|
+ device_map="auto" if args.torch_dtype == "int8" else None,
|
|
|
)
|
|
|
if args.torch_dtype == "int8":
|
|
|
# trigger weight quantization
|