1.掌握pytorch模型转换到onnx模型
2.顺利运行onnx模型
3.比对onnx模型和pytorch模型的输出结果
前提条件:需要安装onnx 和 onnxruntime,可以通过 pip install onnx 和 pip install onnxruntime 进行安装
pytorch 转 onnx 只需要一个函数 torch.onnx.export
torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
参数说明:
import torch import torch.nn import onnx model = torch.load('best.pt') model.eval() input_names = ['input'] output_names = ['output'] x = torch.randn(1,3,32,32,requires_grad=True) torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')
检查onnx模型,并使用onnxruntime运行。
import onnx import onnxruntime as ort model = onnx.load('best.onnx') onnx.checker.check_model(model) session = ort.InferenceSession('best.onnx') x=np.random.randn(1,3,32,32).astype(np.float32) # 注意输入type一定要np.float32!!!!! # x= torch.randn(batch_size,chancel,h,w) outputs = session.run(None,input = { 'input' : x })
参数说明:
import numpy as np np.testing.assert_allclose(torch_result[0].detach().numpu(),onnx_result,rtol=0.0001)
如前所述,经验表明,ONNX 模型的运行效率明显优于原 PyTorch 模型,这似乎是源于 ONNX 模型生成过程中的优化,这也导致了模型的生成过程比较耗时,但整体效率依旧可观。
此外,根据对 ONNX 模型和 PyTorch 模型运行结果的统计分析(误差的均值和标准差),可以看出 ONNX 模型的运行结果误差很小、基本可靠。
内容参考:https://zhuanlan.zhihu.com/p/422290231