
save non-transformer params separately

justheuristic · 3 years ago
commit 324ea2dc96
1 changed file with 4 additions and 0 deletions

+ 4 - 0
cli/quantize_for_cpu.py

@@ -47,3 +47,7 @@ if __name__ == "__main__":
             layer_fp32, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
         )
         torch.save(layer_quantized.state_dict(), os.path.join(args.output_path, f"block_{i}_qint8.pth"))
+
+    model.transformer.h = torch.nn.ModuleList()
+    torch.save(model.state_dict(), os.path.join(args.output_path, f"client.pth"))
+
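The change above saves each quantized transformer block on its own, then empties `model.transformer.h` so that only the non-transformer parameters (embeddings, norms, head) land in `client.pth`. A minimal runnable sketch of the same pattern follows; `ToyModel`, its embedding, and the output paths are illustrative assumptions, not the original script's definitions:

```python
import os
import tempfile

import torch

class ToyModel(torch.nn.Module):
    """Toy GPT-like layout: .transformer.h holds the blocks, .transformer.wte the embedding."""
    def __init__(self):
        super().__init__()
        self.transformer = torch.nn.Module()
        self.transformer.wte = torch.nn.Embedding(100, 16)  # stays with the client
        self.transformer.h = torch.nn.ModuleList(
            [torch.nn.Linear(16, 16) for _ in range(2)]  # blocks saved quantized
        )

model = ToyModel()
out_dir = tempfile.mkdtemp()

# Dynamically quantize each block's Linear layers to qint8 and save it separately.
for i, block in enumerate(model.transformer.h):
    quantized = torch.quantization.quantize_dynamic(
        block, {torch.nn.Linear}, dtype=torch.qint8
    )
    torch.save(quantized.state_dict(), os.path.join(out_dir, f"block_{i}_qint8.pth"))

# Drop the blocks, then save everything else (here just the embedding) for the client.
model.transformer.h = torch.nn.ModuleList()
torch.save(model.state_dict(), os.path.join(out_dir, "client.pth"))

client_state = torch.load(os.path.join(out_dir, "client.pth"))
print(sorted(client_state.keys()))  # no "transformer.h.*" keys remain
```

Emptying `transformer.h` before the final `torch.save` is what keeps the per-block checkpoints and the client checkpoint disjoint: the client file never duplicates block weights that are already stored in quantized form.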