@@ -47,3 +47,7 @@ if __name__ == "__main__":
             layer_fp32, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
         )
         torch.save(layer_quantized.state_dict(), os.path.join(args.output_path, f"block_{i}_qint8.pth"))
+
+    model.transformer.h = torch.nn.ModuleList()
+    torch.save(model.state_dict(), os.path.join(args.output_path, f"client.pth"))
+
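
For context, a minimal sketch of how the artifacts written above might be loaded back and stitched together on the consuming side. It assumes a GPT-2-style model from `transformers`; the model class, `GPT2Block`, `output_path`, and the block count taken from `config.n_layer` are illustrative assumptions and do not come from this diff, and the `qconfig_spec` passed to `quantize_dynamic` must match whatever `qconfig` the original script defines.

```python
import os

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

output_path = "quantized_model"              # hypothetical; mirrors args.output_path
config = AutoConfig.from_pretrained("gpt2")  # assumed architecture for illustration

# Rebuild the "client" part (embeddings, final layer norm, LM head) from client.pth;
# the transformer blocks were stripped before saving, so they are loaded separately.
model = AutoModelForCausalLM.from_config(config)
model.transformer.h = torch.nn.ModuleList()
model.load_state_dict(torch.load(os.path.join(output_path, "client.pth")))

# Re-create each block, convert it to the same dynamically quantized layout,
# then load the saved qint8 weights into it. The qconfig_spec here uses the
# default dynamic qconfig and must match the one used when saving.
blocks = []
for i in range(config.n_layer):
    block = GPT2Block(config, layer_idx=i)
    block = torch.quantization.quantize_dynamic(
        block, {torch.nn.Linear}, dtype=torch.qint8
    )
    block.load_state_dict(torch.load(os.path.join(output_path, f"block_{i}_qint8.pth")))
    blocks.append(block)

model.transformer.h = torch.nn.ModuleList(blocks)
model.eval()
```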