一、Type Error: Type 'tensor(bool)' of input parameter (121) of operator (ScatterND) in node (ScatterND_128) is invalid
問題
模型轉出成功后,用onnxruntime加載,出現不支持參數問題, 這里出現tensor(bool)是因為代碼中使用了bool類型的索引
解決措施
索引采用torch.where替代
...
mask = dist < distance
distance[mask] = dist[mask]
...
更改為
distance = torch.where(dist < distance, dist, distance)
二、FAIL : Load model from ./test.onnx failed:Fatal error: ATen is not a registered function/op
問題
模型轉出成功后,用onnxruntime加載,出現沒有注冊的算子
解決措施
在torch.onnx.export函數中設置opset_version=12
三、動態輸入/輸出
有時候輸入和輸出維度是變化的,這個時候在導出的時候可以添加dynamic_axes參數,並指定哪些參數和維度是動態的。

結果

四、Removing initializer 'bn1.num_batches_tracked'. It is not used by any node and should be removed from the model.
問題
模型轉出成功后,用onnxruntime運行出現以上警告
解決措施
對模型進行優化
import onnx
import onnxoptimizer # pip install onnxoptimizer
onnx_model = onnx.load(onnxfile)
passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]
optimized_model = onnxoptimizer.optimize(onnx_model, passes)
onnx.save(optimized_model, onnxfile)
