ResNet-50模型圖像分類示例
概述
計算機視覺是當前深度學習研究最廣泛、落地最成熟的技術領域,在手機拍照、智能安防、自動駕駛等場景有廣泛應用。從2012年AlexNet在ImageNet比賽奪冠以來,深度學習深刻推動了計算機視覺領域的發展,當前最先進的計算機視覺算法幾乎都是深度學習相關的。深度神經網絡可以逐層提取圖像特征,並保持局部不變性,被廣泛應用於分類、檢測、分割、檢索、識別、提升、重建等視覺任務中。
本文結合圖像分類任務,介紹MindSpore如何應用於計算機視覺場景。
圖像分類
圖像分類是最基礎的計算機視覺應用,屬於有監督學習類別。給定一張數字圖像,判斷圖像所屬的類別,如貓、狗、飛機、汽車等等。用函數來表示這個過程如下:
label = model(image)
return label
選擇合適的model是關鍵。這里的model一般指的是深度卷積神經網絡,如AlexNet、VGG、GoogLeNet、ResNet等等。
MindSpore實現了典型的卷積神經網絡,開發者可以參考model_zoo。
MindSpore當前支持的圖像分類網絡包括:典型網絡LeNet、AlexNet、ResNet。
任務描述及准備

圖1:CIFAR-10數據集[1]
如圖1所示,CIFAR-10數據集共包含10類、共60000張圖片。其中,每類圖片6000張,50000張是訓練集,10000張是測試集。每張圖片大小為32*32。
圖像分類的訓練指標通常是精度(Accuracy),即正確預測的樣本數占總預測樣本數的比值。
接下來介紹利用MindSpore解決圖片分類任務,整體流程如下:
- 下載CIFAR-10數據集
- 數據加載和預處理
- 定義卷積神經網絡,本例采用ResNet-50網絡
- 定義損失函數和優化器
- 調用Model高階API進行訓練和保存模型文件
- 加載保存的模型進行推理
本例面向Ascend 910 AI處理器硬件平台,你可以在這里下載完整的樣例代碼:https://gitee.com/mindspore/docs/tree/r1.1/tutorials/tutorial_code/resnet
下面對任務流程中各個環節及代碼關鍵片段進行解釋說明。
下載CIFAR-10數據集
先從CIFAR-10數據集官網上下載CIFAR-10數據集。本例中采用binary格式的數據,Linux環境可以通過下面的命令下載:
接下來需要解壓數據集,解壓命令如下:
數據預加載和預處理
- 加載數據集
數據加載可以通過內置數據集格式Cifar10Dataset接口完成。
Cifar10Dataset,讀取類型為隨機讀取,內置CIFAR-10數據集,包含圖像和標簽,圖像格式默認為uint8,標簽數據格式默認為uint32。更多說明請查看API中Cifar10Dataset接口說明。
數據加載代碼如下,其中data_home為數據存儲位置:
- 數據增強
數據增強主要是對數據進行歸一化和豐富數據樣本數量。常見的數據增強方式包括裁剪、翻轉、色彩變化等等。MindSpore通過調用map方法在圖片上執行增強操作:
resize_width = 224
rescale = 1.0 / 255.0
shift = 0.0
# define map operations
random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = C.RandomHorizontalFlip()
resize_op = C.Resize((resize_height, resize_width)) # interpolation default BILINEAR
rescale_op = C.Rescale(rescale, shift)
normalize_op = C.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
changeswap_op = C.HWC2CHW()
type_cast_op = C2.TypeCast(mstype.int32)
c_trans = []
if training:
c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]
# apply map operations on images
cifar_ds = cifar_ds.map(operations=type_cast_op, input_columns="label")
cifar_ds = cifar_ds.map(operations=c_trans, input_columns="image")
- 數據混洗和批處理
最后通過數據混洗(shuffle)隨機打亂數據的順序,並按batch讀取數據,進行模型訓練:
cifar_ds = cifar_ds.shuffle(buffer_size=10)
# apply batch operations
cifar_ds = cifar_ds.batch(batch_size=args_opt.batch_size, drop_remainder=True)
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)
定義卷積神經網絡
卷積神經網絡已經是圖像分類任務的標准算法了。卷積神經網絡采用分層的結構對圖片進行特征提取,由一系列的網絡層堆疊而成,比如卷積層、池化層、激活層等等。
ResNet通常是較好的選擇。首先,它足夠深,常見的有34層,50層,101層。通常層次越深,表征能力越強,分類准確率越高。其次,可學習,采用了殘差結構,通過shortcut連接把低層直接跟高層相連,解決了反向傳播過程中因為網絡太深造成的梯度消失問題。此外,ResNet網絡的性能很好,既表現為識別的准確率,也包括它本身模型的大小和參數量。
MindSpore Model Zoo中已經實現了ResNet模型,可以采用ResNet-50。調用方法如下:
定義損失函數和優化器
接下來需要定義損失函數(Loss)和優化器(Optimizer)。損失函數是深度學習的訓練目標,也叫目標函數,可以理解為神經網絡的輸出(Logits)和標簽(Labels)之間的距離,是一個標量數據。
常見的損失函數包括均方誤差、L2損失、Hinge損失、交叉熵等等。圖像分類應用通常采用交叉熵損失(CrossEntropy)。
優化器用於神經網絡求解(訓練)。由於神經網絡參數規模龐大,無法直接求解,因而深度學習中采用隨機梯度下降算法(SGD)及其改進算法進行求解。MindSpore封裝了常見的優化器,如SGD、ADAM、Momemtum等等。本例采用Momentum優化器,通常需要設定兩個參數,動量(moment)和權重衰減項(weight decay)。
MindSpore中定義損失函數和優化器的代碼樣例如下:
ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
# optimization definition
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
調用Model高階API進行訓練和保存模型文件
完成數據預處理、網絡定義、損失函數和優化器定義之后,就可以進行模型訓練了。模型訓練包含兩層迭代,數據集的多輪迭代(epoch)和一輪數據集內按分組(batch)大小進行的單步迭代。其中,單步迭代指的是按分組從數據集中抽取數據,輸入到網絡中計算得到損失函數,然后通過優化器計算和更新訓練參數的梯度。
為了簡化訓練過程,MindSpore封裝了Model高階接口。用戶輸入網絡、損失函數和優化器完成Model的初始化,然后調用train接口進行訓練,train接口參數包括迭代次數(epoch)和數據集(dataset)。
模型保存是對訓練參數進行持久化的過程。Model類中通過回調函數(callback)的方式進行模型保存,如下面代碼所示。用戶通過CheckpointConfig設置回調函數的參數,其中,save_checkpoint_steps指每經過固定的單步迭代次數保存一次模型,keep_checkpoint_max指最多保存的模型個數。
network, loss, optimizer are defined before.
batch_num, epoch_size are training parameters.
'''
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
# CheckPoint CallBack definition
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=35)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
# LossMonitor is used to print loss value on screen
loss_cb = LossMonitor()
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
加載保存的模型,並進行驗證
訓練得到的模型文件(如resnet.ckpt)可以用來預測新圖像的類別。首先通過load_checkpoint加載模型文件。然后調用Model的eval接口預測新圖像類別。
load_param_into_net(net, param_dict)
eval_dataset = create_dataset(training=False)
res = model.eval(eval_dataset)
print("result: ", res)
參考文獻
[1] https://www.cs.toronto.edu/~kriz/cifar.html
