Browse Source

Set device_map only for int8

Max Ryabinin 2 years ago
parent
commit
556f0fabe0
1 changed files with 1 additions and 1 deletions
  1. 1 1
      src/petals/cli/convert_model.py

+ 1 - 1
src/petals/cli/convert_model.py

@@ -60,7 +60,7 @@ def main():
         revision=args.revision,
         torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
         load_in_8bit=args.torch_dtype == "int8",
-        device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"},
+        device_map="auto" if args.torch_dtype == "int8" else None,
     )
     if args.torch_dtype == "int8":
         # trigger weight quantization