Fine-Tuning微調原理


Fine-Tuning微調原理

如何在只有60000張圖片的Fashion-MNIST訓練數據集中訓練模型。ImageNet,這是學術界使用最廣泛的大型圖像數據集,它擁有1000多萬幅圖像和1000多個類別的對象。然而,我們經常處理的數據集的大小通常比第一個大,但比第二個小。             

假設我們想在圖像中識別不同種類的椅子,然后將購買鏈接推給用戶。一種可行的方法是先找到一百張常見的椅子,每把椅子取一千張不同角度的圖像,然后在采集到的圖像數據集上訓練分類模型。雖然這個數據集可能比時尚MNIST大,但是示例的數量仍然不到ImageNet的十分之一。這可能導致適用於ImageNet的復雜模型在此數據集上過度擬合。同時,由於數據量有限,最終訓練出的模型精度可能達不到實際要求。             

為了解決上述問題,一個顯而易見的解決辦法就是收集更多的數據。然而,收集和標記數據會消耗大量的時間和金錢。例如,為了收集ImageNet的數據集,研究人員花費了數百萬美元的研究經費。盡管近年來,數據采集成本大幅下降,但成本仍然不容忽視。             

另一種解決方案是應用轉移學習將從源數據集學習的知識遷移到目標數據集。例如,雖然ImageNet中的圖像大多與椅子無關,但是在這個數據集上訓練的模型可以提取更一般的圖像特征,這些特征可以幫助識別邊緣、紋理、形狀和對象組成。這些相似的特征對於識別椅子同樣有效。             

在本節中,我們將介紹遷移學習中的一種常用技術:微調。如圖13.2.1所示,微調包括以下四個步驟:             

在源數據集(例如ImageNet數據集)上預訓練神經網絡模型,即源模型。             

建立一個新的神經網絡模型,即目標模型。這將復制源模型上的所有模型設計及其參數,輸出層除外。我們假設這些模型參數包含從源數據集學習到的知識,這些知識將同樣適用於目標數據集。我們還假設源模型的輸出層與源數據集的標簽密切相關,因此不在目標模型中使用。             

將輸出大小為目標數據集類別數的輸出層添加到目標模型中,並隨機初始化該層的模型參數。             

在目標數據集上訓練目標模型,例如椅子數據集。我們將從頭開始訓練輸出層,同時根據源模型的參數對所有剩余層的參數進行微調。

 

Fig. 1.  Fine tuning.

1. Hot Dog Recognition

我們將使用一個具體的例子來練習:熱狗識別。我們將基於一個小的數據集,對在ImageNet數據集上訓練的ResNet模型進行微調。這個小數據集包含數千張圖像,其中一些包含熱狗。我們將使用通過微調獲得的模型來識別圖像是否包含熱狗。              

首先,導入實驗所需的軟件包和模塊。Gluon的model_zoo package提供了一個通用的預訓練模型。如果你想獲得更多的計算機視覺的預先訓練模型,你可以使用GluonCV工具箱。

%matplotlib inline

from d2l import mxnet as d2l

from mxnet import gluon, init, np, npx

from mxnet.gluon import nn

import os

 npx.set_np()

1.1. Obtaining the Dataset

我們使用的熱狗數據集來自在線圖像,包含1400個熱狗的正面圖片和其他食物的相同數量的負面圖片。1000個各種課程的圖像用於訓練,其余的用於測試。             

我們首先下載壓縮數據集,得到兩個文件夾hotdog/train和hotdog/test。這兩個文件夾都有hotdog和not hotdog類別子文件夾,每個子文件夾都有相應的圖像文件。

#@save

d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL+'hotdog.zip',

                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')

Downloading ../data/hotdog.zip from http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip...

我們創建兩個ImageFolderDataset實例,分別讀取訓練數據集和測試數據集中的所有圖像文件。

train_imgs = gluon.data.vision.ImageFolderDataset(

    os.path.join(data_dir, 'train'))

test_imgs = gluon.data.vision.ImageFolderDataset(

os.path.join(data_dir, 'test'))

前8個正面示例和最后8個負面圖像如下所示。如您所見,圖像的大小和縱橫比各不相同。

hotdogs = [train_imgs[i][0] for i in range(8)]

not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]

d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

 

在訓練過程中,我們首先從圖像中裁剪出一個大小和縱橫比隨機的隨機區域,然后將該區域縮放到一個高度和寬度為224像素的輸入。在測試過程中,我們將圖像的高度和寬度縮放到256像素,然后裁剪高寬為224像素的中心區域作為輸入。此外,我們規范化三個RGB(紅色、綠色和藍色)顏色通道的值。從每個值中減去信道所有值的平均值,然后將結果除以信道所有值的標准差,以產生輸出。

# We specify the mean and variance of the three RGB channels to normalize the

# image channel

normalize = gluon.data.vision.transforms.Normalize(

    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_augs = gluon.data.vision.transforms.Compose([

    gluon.data.vision.transforms.RandomResizedCrop(224),

    gluon.data.vision.transforms.RandomFlipLeftRight(),

    gluon.data.vision.transforms.ToTensor(),

    normalize])

test_augs = gluon.data.vision.transforms.Compose([

    gluon.data.vision.transforms.Resize(256),

    gluon.data.vision.transforms.CenterCrop(224),

    gluon.data.vision.transforms.ToTensor(),

    normalize])

1.2. Defining and Initializing the Model

我們使用ResNet-18作為源模型,ResNet-18是在ImageNet數據集上預先訓練的。這里,我們指定pretrained=True以自動下載和加載預先訓練的模型參數。第一次使用時,需要從互聯網上下載模型參數。

pretrained_net = gluon.model_zoo.vision.resnet18_v2(pretrained=True)

預先訓練的源模型實例包含兩個成員變量:features和output。前者包含模型的所有層,輸出層除外,后者是模型的輸出層。這一划分的主要目的是促進除輸出層之外的所有層的模型參數的微調。源模型的成員變量輸出如下所示。作為一個完全連接的層,它將ResNet最終的全局平均池層輸出轉換為ImageNet數據集上的1000個類輸出。

pretrained_net.output

Dense(512 -> 1000, linear)

然后構建一個新的神經網絡作為目標模型。它的定義方式與預先訓練的源模型相同,但最終輸出數量等於目標數據集中的類別數。在下面的代碼中,目標模型實例finetune_net的成員變量特征中的模型參數初始化為源模型對應層的模型參數。由於特征中的模型參數是通過對ImageNet數據集的預訓練得到的,所以它是足夠好的。因此,我們通常只需要使用較小的學習速率來“微調”這些參數。相比之下,成員變量輸出中的模型參數是隨機初始化的,通常需要更大的學習速率才能從頭開始學習。假設訓練實例中的學習率為 η,學習率為10η,更新成員變量輸出中的模型參數。

finetune_net = gluon.model_zoo.vision.resnet18_v2(classes=2)

finetune_net.features = pretrained_net.features

finetune_net.output.initialize(init.Xavier())

# The model parameters in output will be updated using a learning rate ten

# times greater

finetune_net.output.collect_params().setattr('lr_mult', 10)

1.3. Fine Tuning the Model

我們首先定義了一個訓練函數train_fine_tuning,它使用了微調,因此可以多次調用它。

def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5):

    train_iter = gluon.data.DataLoader(

        train_imgs.transform_first(train_augs), batch_size, shuffle=True)

    test_iter = gluon.data.DataLoader(

        test_imgs.transform_first(test_augs), batch_size)

    ctx = d2l.try_all_gpus()

    net.collect_params().reset_ctx(ctx)

    net.hybridize()

    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    trainer = gluon.Trainer(net.collect_params(), 'sgd', {

        'learning_rate': learning_rate, 'wd': 0.001})

    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, ctx)

我們將訓練器實例中的學習率設置為一個較小的值,如0.01,以便對預訓練中獲得的模型參數進行微調。基於前面的設置,我們將使用10倍以上的學習率從頭開始訓練目標模型的輸出層參數。

train_fine_tuning(finetune_net, 0.01)

loss 0.518, train acc 0.890, test acc 0.927

634.3 examples/sec on [gpu(0), gpu(1)]

 

為了進行比較,我們定義了一個相同的模型,但將其所有模型參數初始化為隨機值。由於整個模型需要從頭開始訓練,所以我們可以使用更大的學習率。

scratch_net = gluon.model_zoo.vision.resnet18_v2(classes=2)

scratch_net.initialize(init=init.Xavier())

train_fine_tuning(scratch_net, 0.1)

loss 0.371, train acc 0.839, test acc 0.784

706.5 examples/sec on [gpu(0), gpu(1)]

 

正如您所看到的,由於參數的初始值更好,微調后的模型往往在同一時代獲得更高的精度。

2. Summary

  • Transfer learning migrates the knowledge learned from the source dataset to the target dataset. Fine tuning is a common technique for transfer learning.
  • The target model replicates all model designs and their parameters on the source model, except the output layer, and fine-tunes these parameters based on the target dataset. In contrast, the output layer of the target model needs to be trained from scratch.
  • Generally, fine tuning parameters use a smaller learning rate, while training the output layer from scratch can use a larger learning rate.

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM