Unet網絡
Unet是一種編碼-解碼結構相結合的神經網絡結構,是一種語義分割網絡。在醫學圖像分割的相關應用中被廣泛使用。使用matlab可以快速實現網絡結構的定義和訓練。
數據集准備
准備待訓練圖像和相對應的標注圖像,將圖像和標注圖像分別存放到不同的目錄中,通過相同的文件名進行一一對應。
%% 數據集加載
dataSetDir = fullfile('./data');
imageDir = fullfile(dataSetDir,'trainingImages');
labelDir = fullfile(dataSetDir,'trainingLabels');
定義像素分類的類別名稱,以及各類別在標注圖像中的亮度值
classNames = ["triangle","background"];
labelIDs = [255 0];
生成訓練數據集對象
imds = imageDatastore(imageDir);
pxds = pixelLabelDatastore(labelDir,classNames,labelIDs);
% ds = pixelLabelImageDatastore(imds,pxds);
ds = combine(imds,pxds);
網絡定義
imageSize = [32 32];
numClasses = 2;
lgraph = unetLayers(imageSize, numClasses)
訓練網絡
options = trainingOptions('sgdm', ...
'InitialLearnRate',1e-3, ...
'MaxEpochs',20, ...
'VerboseFrequency',10);
net = trainNetwork(ds,lgraph,options)
導出ONNX格式的模型,可使用opencv或tensorrt等工具進行應用部署
exportONNXNetwork(net,'myunet.onnx');
測試
pic = imread('.\data\testImages\image_002.jpg');
out2 = predict(net,pic);
subplot(1,2,1)
imshow(pic)
subplot(1,2,2)
imshow(out2(:,:,1))
完成代碼和測試數據
https://download.csdn.net/download/Ango_/16138054