我做的是一個簡單的二分類任務,用了個vgg11,因為要部署到應用,所以將 PyTorch 中定義的模型轉換為 ONNX 格式,然后在 ONNX Runtime 中運行它,那就不用了在機子上配pytorch環境了,然后也試過轉出來的onnx用opencv.dnn來呼叫,發現識別完全不對,據說是opencv的那個包只能做二維的pooling層,不能做三維的,
然后具體的模型轉換以及使用如下代碼所示,僅作為學習筆記咯~(親測可用)
pip install onnx
pip install onnxruntime
首先,將Pytorch模型轉成onnx格式,然后驗證一波onnx模型有沒有什么毛病
# coding=gbk
#_*_ coding=utf-8 _*_
import torch
import torchvision
import torch.nn as nn
from vgg import vgg11_bn
from torchvision import models
import time
out_onnx = 'model.onnx'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dummy = torch.randn(1, 1, 128, 128) # 模型的輸入格式
model = vgg11_bn() # 模型
model = nn.DataParallel(model) # 因為這里我用了多執行緒訓練,所以得加上
model.load_state_dict(torch.load('model.pth', map_location='cuda'))
if isinstance(model,torch.nn.DataParallel):
model = model.module
model = model.to(device)
dummy = dummy.to(device)
# 定義輸入的名字和輸出的名字,好像可有可無
input_names = ["input"]
output_names = ["output"]
# 輸出pytorch to onnx
torch_out = torch.onnx.export(model, dummy, out_onnx, input_names=input_names, output_names=output_names)
print("finish!") # 搞定
time.sleep(5)
# 驗證 Check the model
import onnx
onnx_model = onnx.load(out_onnx)
print('The model is:\n{}'.format(onnx_model))
try:
onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
print('The model is invalid: %s' % e)
else:
print('The model is valid!')
output = onnx_model.graph.output
print(output)
然后就是用onnx在python不import torch的情況下做推理的測驗:
# coding=gbk
#_*_ coding=utf-8 _*_
import numpy as np
from PIL import Image
img_input = "test.jpg"
imagec = Image.open(img_input).convert("L") # 加載影像,我的是灰度圖
imagec = imagec.resize((128, 128)) # resize
# 我參考的代碼是先ToTensor,然后做ununsqueeze到(1, 1, 128, 128),
# 再to_numpy,像素值是0-1,而PIL的是0-255,所以這里除了個255,再做ununsqueeze
image = np.array(imagec,dtype=np.float32)/255.
image = np.expand_dims(image, axis=0)
image = np.expand_dims(image, axis=0)
print(image) # ---> (1, 1, 128, 128)
import onnxruntime
import onnx
##onnx測驗
onnx_model_path = "models.onnx"
session = onnxruntime.InferenceSession(onnx_model_path)
#compute ONNX Runtime output prediction
inputs = {session.get_inputs()[0].name: image}
logits = session.run(None, inputs)[0]
print("onnx weights", logits)
print("onnx prediction", logits.argmax(axis=1)[0])
參考了很多大佬的文章
https://blog.csdn.net/ouening/article/details/109245243
https://zhuanlan.zhihu.com/p/363177241
https://www.cnblogs.com/sddai/p/14537381.html
還有的大佬鏈接找不到了,感謝大佬們
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/320972.html
標籤:AI
上一篇:R語言Welch方差分析(Welch’s ANOVA)實戰:Welch方差分析是典型的單因素方差分析的一種替代方法,當方差相等的假設被違反時我們無法使用單因素方差分析,這時候Welch’s出來救場了
