Training VGG11 from Scratch with PyTorch


In a previous post, we covered how to build a VGG11 network with PyTorch;

for details, see Building a VGG Network with PyTorch: VGG11 as an Example.

 

In this post, we will take the VGG11 network built earlier and train it from scratch, to better understand the process of building and training a model.

Main contents:

  • Dataset and directory structure: training VGG11 on handwritten digit recognition
  • Coding:

    1) preparing the dataset;

    2) training and validating the model;

    3) the optimizer;

    4) checking the accuracy after each epoch;

  • Analyzing the training loss and accuracy
  • Testing the trained model on new images

1. Dataset and Directory Structure

Dataset: handwritten Digit MNIST

Loaded with the torchvision.datasets module;

Directory structure:

├── input
│   └── test_data
│       ├── eight.jpg
│       ├── two.jpg
│       └── zero.jpg
├── outputs
│   ├── accuracy.jpg
│   ├── loss.jpg
│   ...
├── src
│   ├── data
│   │   └── MNIST
│   ...
│   ├── models.py
│   ├── test.py
│   └── train.py

 

2. Coding

1) Network model script: for loading the VGG11 model (models.py), see Building a VGG Network with PyTorch: VGG11 as an Example.

2) Writing the training script train.py:

2.1) Importing the required packages:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from models import VGG11
matplotlib.style.use('ggplot')

2.2) Defining the training parameters and the device:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[INFO]: Computation device: {device}")
epochs = 10
batch_size = 32

If you hit an OOM (out of memory) error, reduce the batch size to 16, 8, or 4 to fit your GPU; an alternative that keeps the effective batch size is sketched below.
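Not part of the original script: if shrinking the batch size hurts convergence, gradient accumulation lowers per-step memory while preserving the effective batch size. A minimal sketch, assuming the model, criterion, optimizer, and train_dataloader defined later in this post:

# hedged sketch: gradient accumulation as an OOM workaround (not in the
# original script); assumes model, criterion, optimizer, train_dataloader
# and device are defined as in the rest of this post
accum_steps = 4  # effective batch size = batch_size * accum_steps
optimizer.zero_grad()
for step, (images, labels) in enumerate(train_dataloader):
    images, labels = images.to(device), labels.to(device)
    loss = criterion(model(images), labels)
    (loss / accum_steps).backward()  # scale so accumulated gradients average
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()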

2.3) Image transforms

# our transforms differ a bit from the VGG paper:
# since we are using the MNIST dataset, we directly resize
# the images to 224x224 without cropping, and we do not use
# any random flipping either
train_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize(mean=(0.5,), std=(0.5,))])
valid_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize(mean=(0.5,), std=(0.5,))])

Note: the original paper uses random image flipping for augmentation, which we do not need in our training.

Resize scales each image to 224x224; ToTensor converts it to a tensor; Normalize standardizes the pixel values. The mean and std are one-element tuples because MNIST images have a single channel.
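A quick sanity check (not in the original post) that the transform produces a single-channel 224x224 tensor with values mapped into [-1, 1]:

from PIL import Image

# dummy 28x28 grayscale image, the same mode as raw MNIST samples
dummy = Image.new('L', (28, 28))
x = train_transform(dummy)
print(x.shape)         # torch.Size([1, 224, 224])
print(x.min().item())  # -1.0: pixel value 0 maps to (0 - 0.5) / 0.5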


2.4) Loading the data

Next, we prepare the training and validation datasets and their data loaders.

# training dataset and data loader
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                           download=True, 
                                           transform=train_transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, 
                                               batch_size=batch_size,
                                               shuffle=True)
# validation dataset and dataloader
valid_dataset = torchvision.datasets.MNIST(root='./data', train=False,
                                           download=True, 
                                           transform=valid_transform)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, 
                                               batch_size=batch_size,
                                               shuffle=False)
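A quick check (not in the original script) that the loaders yield batches of the expected shape:

# fetch one batch to verify shapes before training
images, labels = next(iter(train_dataloader))
print(images.shape)  # torch.Size([32, 1, 224, 224])
print(labels.shape)  # torch.Size([32])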

2.5) Model initialization, loss function, and optimizer

# instantiate the model
model = VGG11(in_channels=1, num_classes=10).to(device)
# total parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"[INFO]: {total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"[INFO]: {total_trainable_params:,} trainable parameters.")
# the loss function
criterion = nn.CrossEntropyLoss()
# the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, 
                      weight_decay=0.0005)
  • The MNIST images are grayscale, so the input channel count is 1; there are 10 output classes;
  • We use the cross-entropy loss function; SGD performs the parameter updates. An optional learning-rate scheduler is sketched below.
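Optionally (not in the original script), a StepLR scheduler can decay the learning rate as training progresses; a minimal sketch:

# hypothetical addition: decay the learning rate by 10x every 5 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# call scheduler.step() once per epoch, after the validation pass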

2.6) The training function

# training
def train(model, trainloader, optimizer, criterion):
    model.train()
    print('Training')
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
        counter += 1

        image, labels = data
        image = image.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # forward pass
        outputs = model(image)
        # calculate the loss
        loss = criterion(outputs, labels)
        train_running_loss += loss.item()
        # calculate the accuracy
        _, preds = torch.max(outputs.data, 1)
        train_running_correct += (preds == labels).sum().item()
        loss.backward()
        optimizer.step()
    
    epoch_loss = train_running_loss / counter
    epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
    return epoch_loss, epoch_acc
  • Iterate over the training data loader, extracting the images and labels;
  • Move the images and labels to the computation device;
  • Run the forward pass, compute the loss and accuracy, then backpropagate the loss and step the optimizer;
  • Return the loss and accuracy for the current epoch.

2.7) The validation function:

The validation function differs a little. For each epoch, we again compute the loss and accuracy;

we also compute the accuracy for each class, to evaluate how the model performs on each digit as training progresses.

# validation
def validate(model, testloader, criterion):
    model.eval()
    
    # we need two lists to keep track of class-wise accuracy
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    print('Validation')
    valid_running_loss = 0.0
    valid_running_correct = 0
    counter = 0
    with torch.no_grad():
        for i, data in tqdm(enumerate(testloader), total=len(testloader)):
            counter += 1
            
            image, labels = data
            image = image.to(device)
            labels = labels.to(device)
            # forward pass
            outputs = model(image)
            # calculate the loss
            loss = criterion(outputs, labels)
            valid_running_loss += loss.item()
            # calculate the accuracy
            _, preds = torch.max(outputs.data, 1)
            valid_running_correct += (preds == labels).sum().item()
            # calculate the accuracy for each class
            correct = (preds == labels).squeeze()
            for j in range(len(preds)):  # use j to avoid shadowing the batch index i
                label = labels[j]
                class_correct[label] += correct[j].item()
                class_total[label] += 1
        
    epoch_loss = valid_running_loss / counter
    epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
    # print the accuracy for each class after every epoch
    # the values should increase as the training goes on
    print('\n')
    for i in range(10):
        print(f"Accuracy of digit {i}: {100*class_correct[i]/class_total[i]}")
    return epoch_loss, epoch_acc

 

2.8) The training loop:

# start the training
# lists to keep track of losses and accuracies
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []
for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")

    train_epoch_loss, train_epoch_acc = train(model, train_dataloader, 
                                              optimizer, criterion)
    valid_epoch_loss, valid_epoch_acc = validate(model, valid_dataloader,  
                                                 criterion)
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    train_acc.append(train_epoch_acc)
    valid_acc.append(valid_epoch_acc)

    print('\n')
    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")

    print('-'*50)
  • Store the training and validation losses and accuracies in train_loss, valid_loss, train_acc, and valid_acc;
  • After each epoch, print the training and validation metrics.

The final steps save the trained model and plot the loss and accuracy curves; a sketch of this code follows.
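The original post does not show this code; below is a minimal sketch, consistent with the directory structure above and with the 'model_state_dict' checkpoint key that the inference script loads:

# hedged sketch (not shown in the original post): save the trained model
# under the key that test.py expects, then plot the curves to ../outputs/
torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, '../outputs/model.pth')

# accuracy plot
plt.figure(figsize=(10, 7))
plt.plot(train_acc, color='green', label='train accuracy')
plt.plot(valid_acc, color='blue', label='validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig('../outputs/accuracy.jpg')

# loss plot
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(valid_loss, color='red', label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('../outputs/loss.jpg')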

 

3. Analyzing the Training Loss and Accuracy

 

Figure 1: Accuracy plot

 

Figure 2: Loss plot

 

4. Inference on New Images

import torch
import cv2
import glob
import torchvision.transforms as transforms
import numpy as np
from models import VGG11

Load the trained weights and define the transforms

# inferencing on CPU
device = 'cpu'
# initialize the VGG11 model
model = VGG11(in_channels=1, num_classes=10)
# load the model checkpoint
checkpoint = torch.load('../outputs/model.pth')
# load the trained weights
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
# simple image transforms
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5],
                         std=[0.5])
])

Read the images and pass them through the model

# get all the test image paths
image_paths = glob.glob('../input/test_data/*.jpg')
for i, image_path in enumerate(image_paths):
    orig_img = cv2.imread(image_path)
    # convert to grayscale to make the image single channel
    image = cv2.cvtColor(orig_img, cv2.COLOR_BGR2GRAY)
    image = transform(image)
    # add one extra batch dimension
    image = image.unsqueeze(0).to(device)
    # forward pass the image through the model
    outputs = model(image)
    # get the index of the highest score
    # the highest score indicates the label for the Digit MNIST dataset
    label = np.array(outputs.detach()).argmax()
    print(f"{image_path.split('/')[-1].split('.')[0]}: {label}")
    # put the predicted label on the original image
    cv2.putText(orig_img, str(label), (15, 50), cv2.FONT_HERSHEY_SIMPLEX, 
                2, (0, 255, 0), 2)
    # show and save the results
    cv2.imshow('Result', orig_img)
    cv2.waitKey(0)
    cv2.imwrite(f"../outputs/result_{i}.jpg", orig_img)

Summary:

In this post, we trained a VGG11 network from scratch on the handwritten Digit MNIST dataset;

we walked through model initialization, training, and monitoring the model's accuracy and loss;

finally, we tested the model on new images.

 

