ResNet-50模型圖像分類示例


ResNet-50模型圖像分類示例

概述

計算機視覺是當前深度學習研究最廣泛、落地最成熟的技術領域,在手機拍照、智能安防、自動駕駛等場景有廣泛應用。從2012年AlexNet在ImageNet比賽奪冠以來,深度學習深刻推動了計算機視覺領域的發展,當前最先進的計算機視覺算法幾乎都是深度學習相關的。深度神經網絡可以逐層提取圖像特征,並保持局部不變性,被廣泛應用於分類、檢測、分割、檢索、識別、提升、重建等視覺任務中。

本文結合圖像分類任務,介紹MindSpore如何應用於計算機視覺場景。

圖像分類

圖像分類是最基礎的計算機視覺應用,屬於有監督學習類別。給定一張數字圖像,判斷圖像所屬的類別,如貓、狗、飛機、汽車等等。用函數來表示這個過程如下:

   label = model(image)
   return label

選擇合適的model是關鍵。這里的model一般指的是深度卷積神經網絡,如AlexNet、VGG、GoogLeNet、ResNet等等。

MindSpore實現了典型的卷積神經網絡,開發者可以參考model_zoo

MindSpore當前支持的圖像分類網絡包括:典型網絡LeNet、AlexNet、ResNet。

任務描述及准備

 

圖1:CIFAR-10數據集[1]

如圖1所示,CIFAR-10數據集共包含10類、共60000張圖片。其中,每類圖片6000張,50000張是訓練集,10000張是測試集。每張圖片大小為32*32。

圖像分類的訓練指標通常是精度(Accuracy),即正確預測的樣本數占總預測樣本數的比值。

接下來介紹利用MindSpore解決圖片分類任務,整體流程如下:

  1. 下載CIFAR-10數據集
  2. 數據加載和預處理
  3. 定義卷積神經網絡,本例采用ResNet-50網絡
  4. 定義損失函數和優化器
  5. 調用Model高階API進行訓練和保存模型文件
  6. 加載保存的模型進行推理

本例面向Ascend 910 AI處理器硬件平台,你可以在這里下載完整的樣例代碼:https://gitee.com/mindspore/docs/tree/r1.1/tutorials/tutorial_code/resnet

下面對任務流程中各個環節及代碼關鍵片段進行解釋說明。

下載CIFAR-10數據集

先從CIFAR-10數據集官網上下載CIFAR-10數據集。本例中采用binary格式的數據,Linux環境可以通過下面的命令下載:

接下來需要解壓數據集,解壓命令如下:

數據預加載和預處理

  1. 加載數據集

數據加載可以通過內置數據集格式Cifar10Dataset接口完成。

Cifar10Dataset,讀取類型為隨機讀取,內置CIFAR-10數據集,包含圖像和標簽,圖像格式默認為uint8,標簽數據格式默認為uint32。更多說明請查看API中Cifar10Dataset接口說明。

數據加載代碼如下,其中data_home為數據存儲位置:

  1. 數據增強

數據增強主要是對數據進行歸一化和豐富數據樣本數量。常見的數據增強方式包括裁剪、翻轉、色彩變化等等。MindSpore通過調用map方法在圖片上執行增強操作:

resize_width = 224
rescale = 1.0 / 255.0
shift = 0.0
 
# define map operations
random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = C.RandomHorizontalFlip()
resize_op = C.Resize((resize_height, resize_width)) # interpolation default BILINEAR
rescale_op = C.Rescale(rescale, shift)
normalize_op = C.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
changeswap_op = C.HWC2CHW()
type_cast_op = C2.TypeCast(mstype.int32)
 
c_trans = []
if training:
    c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]
 
# apply map operations on images
cifar_ds = cifar_ds.map(operations=type_cast_op, input_columns="label")
cifar_ds = cifar_ds.map(operations=c_trans, input_columns="image")
  1. 數據混洗和批處理

最后通過數據混洗(shuffle)隨機打亂數據的順序,並按batch讀取數據,進行模型訓練:

cifar_ds = cifar_ds.shuffle(buffer_size=10)
 
# apply batch operations
cifar_ds = cifar_ds.batch(batch_size=args_opt.batch_size, drop_remainder=True)
 
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)

定義卷積神經網絡

卷積神經網絡已經是圖像分類任務的標准算法了。卷積神經網絡采用分層的結構對圖片進行特征提取,由一系列的網絡層堆疊而成,比如卷積層、池化層、激活層等等。

ResNet通常是較好的選擇。首先,它足夠深,常見的有34層,50層,101層。通常層次越深,表征能力越強,分類准確率越高。其次,可學習,采用了殘差結構,通過shortcut連接把低層直接跟高層相連,解決了反向傳播過程中因為網絡太深造成的梯度消失問題。此外,ResNet網絡的性能很好,既表現為識別的准確率,也包括它本身模型的大小和參數量。

MindSpore Model Zoo中已經實現了ResNet模型,可以采用ResNet-50。調用方法如下:

定義損失函數和優化器

接下來需要定義損失函數(Loss)和優化器(Optimizer)。損失函數是深度學習的訓練目標,也叫目標函數,可以理解為神經網絡的輸出(Logits)和標簽(Labels)之間的距離,是一個標量數據。

常見的損失函數包括均方誤差、L2損失、Hinge損失、交叉熵等等。圖像分類應用通常采用交叉熵損失(CrossEntropy)。

優化器用於神經網絡求解(訓練)。由於神經網絡參數規模龐大,無法直接求解,因而深度學習中采用隨機梯度下降算法(SGD)及其改進算法進行求解。MindSpore封裝了常見的優化器,如SGD、ADAM、Momemtum等等。本例采用Momentum優化器,通常需要設定兩個參數,動量(moment)和權重衰減項(weight decay)。

MindSpore中定義損失函數和優化器的代碼樣例如下:

ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
 
# optimization definition
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)

調用Model高階API進行訓練和保存模型文件

完成數據預處理、網絡定義、損失函數和優化器定義之后,就可以進行模型訓練了。模型訓練包含兩層迭代,數據集的多輪迭代(epoch)和一輪數據集內按分組(batch)大小進行的單步迭代。其中,單步迭代指的是按分組從數據集中抽取數據,輸入到網絡中計算得到損失函數,然后通過優化器計算和更新訓練參數的梯度。

為了簡化訓練過程,MindSpore封裝了Model高階接口。用戶輸入網絡、損失函數和優化器完成Model的初始化,然后調用train接口進行訓練,train接口參數包括迭代次數(epoch)和數據集(dataset)。

模型保存是對訓練參數進行持久化的過程。Model類中通過回調函數(callback)的方式進行模型保存,如下面代碼所示。用戶通過CheckpointConfig設置回調函數的參數,其中,save_checkpoint_steps指每經過固定的單步迭代次數保存一次模型,keep_checkpoint_max指最多保存的模型個數。

network, loss, optimizer are defined before.
batch_num, epoch_size are training parameters.
'''
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
 
# CheckPoint CallBack definition
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=35)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
 
# LossMonitor is used to print loss value on screen
loss_cb = LossMonitor()
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])

加載保存的模型,並進行驗證

訓練得到的模型文件(如resnet.ckpt)可以用來預測新圖像的類別。首先通過load_checkpoint加載模型文件。然后調用Model的eval接口預測新圖像類別。

load_param_into_net(net, param_dict)
eval_dataset = create_dataset(training=False)
res = model.eval(eval_dataset)
print("result: ", res)

參考文獻

[1] https://www.cs.toronto.edu/~kriz/cifar.html

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM