OpenCV加載Pytorch模型出現Unsupported Lua type 解決方法
原因
Torch有兩個版本,一個就叫Torch一個專門給Python用的Pytorch,它們訓練完之后保存下來的模型是不一樣的.
說到這問題就很清楚了.OpenCV的ReadNetFromTorch
支持的是前者...
解決方法
那么有沒有解決辦法呢,答案是有的.
PyTorch支持把模型保存為ONNX格式.而這個格式在opencv是支持的.
操作如下:
import torch
import torch.onnx
from torch.autograd import Variable
# ~~~~~~~~~~~~~~~~初始化與訓練模型過程~~~~~~~~~~~~~
# 這是普通的pytorch模型保存方式:
torch.save(net.state_dict(), "torch.pt")
# 這是保存為ONNX的方法:
# 由於PyTorch的模型,是動態調整大小的,這里需要初始化一個指定格式的數據,用來調整模型大小
# 就是和你訓練模型的時候用的數據一樣的格式就行
dummy_input = Variable(torch.randn(1, 1, 28, 28)).to(device)
# 保存模型
torch.onnx.export(net, dummy_input, "torch.onnx")
注意,這里還有個坑!
雖然模型保存成了ONNX格式
,但是OpenCV的ReadTensorFromONNX
並不能加載! 需要用ReadNet
方法加載!