針對回歸訓練卷積神經網絡
此示例說明如何使用卷積神經網絡擬合回歸模型來預測手寫數字的旋轉角度。
卷積神經網絡(CNN 或 ConvNet)是深度學習的基本工具,尤其適用於分析圖像數據。例如,您可以使用 CNN 對圖像進行分類。要預測連續數據(例如角度和距離),可以在網絡末尾包含回歸層。
該示例構造一個卷積神經網絡架構,訓練網絡,並使用經過訓練的網絡預測手寫數字的旋轉角度。這些預測對於光學字符識別很有用。
此外,您可以選擇使用 imrotate
(Image Processing Toolbox™) 旋轉圖像,並可選擇使用 boxplot
(Statistics and Machine Learning Toolbox™) 創建殘差箱線圖。
加載數據
數據集包含手寫數字的合成圖像以及每個圖像的旋轉角度(以度為單位)。
使用 digitTrain4DArrayData
和 digitTest4DArrayData
以四維數組的形式加載訓練圖像和驗證圖像。輸出 YTrain
和 YValidation
是以度為單位的旋轉角度。訓練數據集和驗證數據集各包含 5000 個圖像。
[XTrain,~,YTrain] = digitTrain4DArrayData; [XValidation,~,YValidation] = digitTest4DArrayData;
使用 imshow
顯示 20 個隨機訓練圖像。
numTrainImages = numel(YTrain); figure idx = randperm(numTrainImages,20); for i = 1:numel(idx) subplot(4,5,i) imshow(XTrain(:,:,:,idx(i))) drawnow end
檢查數據歸一化
在訓練神經網絡時,最好確保數據在網絡的所有階段均歸一化。對於使用梯度下降的網絡訓練,歸一化有助於訓練的穩定和加速。如果您的數據比例不佳,則損失可能會變為 NaN
,並且網絡參數在訓練過程中可能發生偏離。歸一化數據的常用方法包括重新縮放數據,使其范圍變為 [0,1],或使其均值為 0 且標准差為 1。您可以歸一化以下數據:
-
輸入數據。在將預測變量輸入到網絡之前對其進行歸一化。在此示例中,輸入圖像已歸一化到范圍 [0,1]。
-
層輸出。您可以使用批量歸一化層來歸一化每個卷積層和全連接層的輸出。
可直接聯系客服QQ交代需求:953586085
-
響應。如果使用批量歸一化層來歸一化網絡末尾的層輸出,則網絡的預測值在訓練開始時就被歸一化。如果響應的比例與這些預測值完全不同,則網絡訓練可能無法收斂。如果您的響應比例不佳,則嘗試對其進行歸一化,並查看網絡訓練是否有所改善。如果在訓練之前將響應歸一化,則必須轉換經過訓練網絡的預測值,以獲得原始響應的預測值。
繪制響應的分布。響應(以度為單位的旋轉角度)大致均勻地分布在 -45 和 45 之間,效果很好,無需歸一化。在分類問題中,輸出是類概率,始終需要歸一化。
figure histogram(YTrain) axis tight ylabel('Counts') xlabel('Rotation Angle')
通常,數據不必完全歸一化。但是,如果在此示例中訓練網絡來預測 100*YTrain
或 YTrain+500
而不是 YTrain
,則損失將變為 NaN
,並且網絡參數在訓練開始時會發生偏離。即使預測 aY + b 的網絡與預測 Y 的網絡之間的唯一差異是對最終全連接層的權重和偏置的簡單重新縮放,也會出現這些結果。
如果輸入或響應的分布非常不均勻或偏斜,您還可以在訓練網絡之前對數據執行非線性變換(例如,取其對數)。
創建網絡層
要解決回歸問題,請創建網絡層並在網絡末尾包含一個回歸層。
第一層定義輸入數據的大小和類型。輸入圖像的大小為 28×28×1。創建與訓練圖像大小相同的圖像輸入層。
網絡的中間層定義網絡的核心架構,大多數計算和學習都在此處進行。
最終層定義輸出數據的大小和類型。對於回歸問題,全連接層必須位於網絡末尾的回歸層之前。創建一個大小為 1 的全連接輸出層以及一個回歸層。
在 Layer
數組中將所有層組合在一起。
layers = [ imageInputLayer([28 28 1]) convolution2dLayer(3,8,'Padding','same') batchNormalizationLayer reluLayer averagePooling2dLayer(2,'Stride',2) convolution2dLayer(3,16,'Padding','same') batchNormalizationLayer reluLayer averagePooling2dLayer(2,'Stride',2) convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer dropoutLayer(0.2) fullyConnectedLayer(1) regressionLayer];
訓練網絡
創建網絡訓練選項。進行 30 輪訓練。將初始學習率設置為 0.001,並在 20 輪訓練后降低學習率。通過指定驗證數據和驗證頻率,監控訓練過程中的網絡准確度。軟件基於訓練數據訓練網絡,並在訓練過程中按固定時間間隔計算基於驗證數據的准確度。驗證數據不用於更新網絡權重。打開訓練進度圖,關閉命令行窗口輸出。
miniBatchSize = 128; validationFrequency = floor(numel(YTrain)/miniBatchSize); options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... 'MaxEpochs',30, ... 'InitialLearnRate',1e-3, ... 'LearnRateSchedule','piecewise', ... 'LearnRateDropFactor',0.1, ... 'LearnRateDropPeriod',20, ... 'Shuffle','every-epoch', ... 'ValidationData',{XValidation,YValidation}, ... 'ValidationFrequency',validationFrequency, ... 'Plots','training-progress', ... 'Verbose',false);
使用 trainNetwork
創建網絡。如果存在兼容的 GPU,此命令會使用 GPU。否則,trainNetwork
將使用 CPU。在 GPU 上進行訓練需要具有 3.0 或更高計算能力的支持 CUDA® 的 NVIDIA® GPU。
net = trainNetwork(XTrain,YTrain,layers,options);
檢查 net
的 Layers
屬性中包含的網絡架構的詳細信息。
net.Layers
ans = 18x1 Layer array with layers: 1 'imageinput' Image Input 28x28x1 images with 'zerocenter' normalization 2 'conv_1' Convolution 8 3x3x1 convolutions with stride [1 1] and padding 'same' 3 'batchnorm_1' Batch Normalization Batch normalization with 8 channels 4 'relu_1' ReLU ReLU 5 'avgpool_1' Average Pooling 2x2 average pooling with stride [2 2] and padding [0 0 0 0] 6 'conv_2' Convolution 16 3x3x8 convolutions with stride [1 1] and padding 'same' 7 'batchnorm_2' Batch Normalization Batch normalization with 16 channels 8 'relu_2' ReLU ReLU 9 'avgpool_2' Average Pooling 2x2 average pooling with stride [2 2] and padding [0 0 0 0] 10 'conv_3' Convolution 32 3x3x16 convolutions with stride [1 1] and padding 'same' 11 'batchnorm_3' Batch Normalization Batch normalization with 32 channels 12 'relu_3' ReLU ReLU 13 'conv_4' Convolution 32 3x3x32 convolutions with stride [1 1] and padding 'same' 14 'batchnorm_4' Batch Normalization Batch normalization with 32 channels 15 'relu_4' ReLU ReLU 16 'dropout' Dropout 20% dropout 17 'fc' Fully Connected 1 fully connected layer 18 'regressionoutput' Regression Output mean-squared-error with response 'Response'
測試網絡
基於驗證數據評估准確度來測試網絡性能。
使用 predict
預測驗證圖像的旋轉角度。
YPredicted = predict(net,XValidation);
評估性能
通過計算以下值來評估模型性能:
-
在可接受誤差界限內的預測值的百分比
-
預測旋轉角度和實際旋轉角度的均方根誤差 (RMSE)
計算預測旋轉角度和實際旋轉角度之間的預測誤差。
predictionError = YValidation - YPredicted;
計算在實際角度的可接受誤差界限內的預測值的數量。將閾值設置為 10 度。計算此閾值范圍內的預測值的百分比。
thr = 10; numCorrect = sum(abs(predictionError) < thr); numValidationImages = numel(YValidation); accuracy = numCorrect/numValidationImages
accuracy = 0.9636
使用均方根誤差 (RMSE) 來衡量預測旋轉角度和實際旋轉角度之間的差異。
squares = predictionError.^2; rmse = sqrt(mean(squares))
rmse = single
4.6393
顯示每個數字類的殘差箱線圖
boxplot
函數需要一個矩陣,其中各個列對應於各個數字類的殘差。
驗證數據按數字類 0-9 對圖像進行分組,每組包含 500 個樣本。使用 reshape
按數字類對殘差進行分組。
residualMatrix = reshape(predictionError,500,10);
residualMatrix
的每列對應於每個數字的殘差。使用 boxplot
(Statistics and Machine Learning Toolbox) 為每個數字創建殘差箱線圖。
figure boxplot(residualMatrix,... 'Labels',{'0','1','2','3','4','5','6','7','8','9'}) xlabel('Digit Class') ylabel('Degrees Error') title('Residuals')
准確度最高的數字類具有接近於零的均值和很小的方差。
校正數字旋轉
您可以使用 Image Processing Toolbox 中的函數來擺正數字並將它們顯示在一起。使用 imrotate
(Image Processing Toolbox) 根據預測的旋轉角度旋轉 49 個樣本數字。
idx = randperm(numValidationImages,49); for i = 1:numel(idx) image = XValidation(:,:,:,idx(i)); predictedAngle = YPredicted(idx(i)); imagesRotated(:,:,:,i) = imrotate(image,predictedAngle,'bicubic','crop'); end
顯示原始數字以及校正旋轉后的數字。您可以使用 montage
(Image Processing Toolbox) 將數字顯示在同一個圖像上。
figure subplot(1,2,1) montage(XValidation(:,:,:,idx)) title('Original') subplot(1,2,2) montage(imagesRotated) title('Corrected')