
Set device_map only for int8

Max Ryabinin 2 years ago
Parent
Current commit
556f0fabe0
1 changed file with 1 addition and 1 deletion

+ 1 - 1
src/petals/cli/convert_model.py

@@ -60,7 +60,7 @@ def main():
         revision=args.revision,
         torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else "float16",
         load_in_8bit=args.torch_dtype == "int8",
-        device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"},
+        device_map="auto" if args.torch_dtype == "int8" else None,
     )
     if args.torch_dtype == "int8":
         # trigger weight quantization
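The change makes `device_map` conditional: only the int8 path passes `device_map="auto"` (so Accelerate places the quantized weights on GPU), while other dtypes pass `None` and leave placement to the caller. A minimal sketch of that keyword-argument logic, using a simplified stand-in `DTYPE_MAP` and a hypothetical `build_load_kwargs` helper (neither is part of the actual `convert_model.py`):

```python
# Simplified stand-in for the DTYPE_MAP in convert_model.py (assumption:
# the real map resolves names to torch dtypes; strings suffice here).
DTYPE_MAP = {"float32": "float32", "float16": "float16", "bfloat16": "bfloat16"}


def build_load_kwargs(torch_dtype: str) -> dict:
    """Build the from_pretrained() kwargs per the patched logic.

    int8 loads weights in float16 with 8-bit quantization enabled and
    device_map="auto"; any other dtype disables quantization and leaves
    device_map as None.
    """
    return {
        "torch_dtype": DTYPE_MAP[torch_dtype] if torch_dtype != "int8" else "float16",
        "load_in_8bit": torch_dtype == "int8",
        "device_map": "auto" if torch_dtype == "int8" else None,
    }


# int8: quantized, auto device placement
print(build_load_kwargs("int8"))
# float16: plain half-precision load, no device_map
print(build_load_kwargs("float16"))
```

Before this commit, the hard-coded per-module `device_map` forced every load onto CUDA; gating it on int8 restores CPU-friendly conversion for the non-quantized dtypes.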