@@ -15,16 +15,17 @@ from petals.client import DistributedBloomConfig

 logger = get_logger(__name__)

-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
-
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, int8=torch.int8, auto="auto")

 def main():
     parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")

     parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
     parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
-    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
-    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
+    parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
+                        help="Load initial model in this dtype")
+    parser.add_argument("--output_path", type=str, default="./converted_model",
+                        help="Track output repo to this folder")
     parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
     parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
     parser.add_argument(
@@ -41,7 +42,6 @@ def main():
     if args.model == "bigscience/bloom" and free_ram_gb < 400:
         logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")

-    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
     if os.path.exists(args.output_path) and (
         len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
     ):
@@ -54,8 +54,15 @@ def main():
     config.dht_prefix = args.output_repo

     model = BloomModel.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
+        args.model, use_auth_token=args.use_auth_token, revision=args.revision,
+        torch_dtype=DTYPE_MAP[args.torch_dtype] if args.torch_dtype != "int8" else torch.float16,
+        load_in_8bit=args.torch_dtype == "int8",
+        device_map={"word_embeddings": "cuda", "word_embeddings_layernorm": "cuda", "h": "cuda", "ln_f": "cuda"}
     )
+    if args.torch_dtype == "int8":
+        # trigger weight quantization by moving the model to GPU
+        model = model.cuda()
+
     if args.resize_token_embeddings:
         logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
         model.resize_token_embeddings(args.resize_token_embeddings)
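A note on the int8 path introduced above: with load_in_8bit=True, transformers (with the bitsandbytes and accelerate packages installed) swaps linear layers for LLM.int8() modules and quantizes the fp16 weights as they are placed on the GPU, which is why the int8 branch loads the model in float16 and then calls .cuda(). Below is a minimal sketch of the same loading path outside this script; the small example checkpoint and the probed module path are illustrative, not part of this diff:

    import torch
    from transformers import BloomModel

    # Sketch: load fp16 weights and quantize linear layers to int8 on the fly
    # (bitsandbytes LLM.int8()). Requires a CUDA device plus the bitsandbytes
    # and accelerate packages. The checkpoint name is only an example.
    model = BloomModel.from_pretrained(
        "bigscience/bloom-560m",
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",  # let accelerate place modules; the converter pins them to "cuda"
    )

    # The linear projections should now hold int8 weights:
    print(model.h[0].self_attention.query_key_value.weight.dtype)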
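Assuming the script lives at cli/convert_model.py (the usual petals repo layout, an assumption rather than something stated in this diff), the int8 path would be selected from the command line like so:

    python -m cli.convert_model --model bigscience/bloom-6b3 --torch_dtype int8 --output_path ./converted_model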