In an earlier post we showed how to build a VGG11 network with PyTorch. In this post we will take that VGG11 network and train it by hand, to get a better feel for how model building and training fit together.
Main contents:
- Dataset and directory structure: training VGG11 on handwritten digit recognition
- Coding:
  1) preparing the dataset;
  2) training and validating the model;
  3) the optimizer;
  4) checking the accuracy after each epoch;
- Analyzing the training loss and accuracy
- Running inference on new images
1. Dataset and directory structure
Dataset: handwritten Digit MNIST, loaded with the torchvision.datasets module.
Directory structure:
```
├── input
│   └── test_data
│       ├── eight.jpg
│       ├── two.jpg
│       └── zero.jpg
├── outputs
│   ├── accuracy.jpg
│   └── loss.jpg
│   ...
├── src
│   ├── data
│   │   └── MNIST
│   │       ...
│   ├── models.py
│   ├── test.py
│   └── train.py
```
2. Coding
1) The network model script: the VGG11 model lives in models.py; for the full walkthrough, see the earlier post "Building a VGG Network with PyTorch: VGG11 as an Example". A minimal sketch is included below for reference.
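The sketch below is an assumption based on the standard VGG11 configuration from the paper, not a verbatim copy of the earlier post; the important part is that the constructor signature VGG11(in_channels, num_classes) matches how the model is used in train.py and test.py.

```python
import torch.nn as nn

# a minimal VGG11 sketch (assumed layout; see the earlier post for the
# actual implementation): 8 conv layers followed by 3 fully connected layers
class VGG11(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
        layers, channels = [], in_channels
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [nn.Conv2d(channels, v, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True)]
                channels = v
        self.features = nn.Sequential(*layers)
        # after five 2x2 max-poolings, a 224x224 input is reduced to 7x7
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(4096, num_classes))

    def forward(self, x):
        x = self.features(x)
        x = x.flatten(1)  # flatten every dimension except the batch
        return self.classifier(x)
```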
2) Writing the training script train.py:
2.1) Importing the required packages:
```python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from models import VGG11

matplotlib.style.use('ggplot')
```
2.2) Defining the training parameters and the computation device:
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[INFO]: Computation device: {device}")
epochs = 10
batch_size = 32
```
If you run into an OOM (out of memory) error, reduce the batch size to 16, 8, or 4 to fit your GPU.
2.3) Image transforms
```python
# our transforms differ a bit from the VGG paper: since we are using
# the MNIST dataset, we resize the images directly to 224x224 without
# cropping, and we do not use any random flipping either
train_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize(mean=(0.5,), std=(0.5,))])
valid_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize(mean=(0.5,), std=(0.5,))])
```
Note: the original paper uses random image flipping; for our training it is not needed.
- Resize() scales each image to 224x224;
- ToTensor() converts the image to a PyTorch tensor;
- Normalize() standardizes the tensor with the given mean and standard deviation.
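As a quick sanity check (a hypothetical snippet, not part of the original script), you can apply train_transform to a single MNIST sample and inspect the result; with mean 0.5 and std 0.5 the pixel values land in the range [-1, 1]:

```python
# hypothetical sanity check: inspect one transformed MNIST sample
from torchvision.datasets import MNIST

sample_dataset = MNIST(root='./data', train=True, download=True,
                       transform=train_transform)
img, label = sample_dataset[0]
print(img.shape)                            # torch.Size([1, 224, 224])
print(img.min().item(), img.max().item())   # roughly -1.0 and 1.0
print(label)                                # 5 (the first training digit)
```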
2.4) Loading the data
Next, we prepare the training and validation datasets along with their data loaders.
```python
# training dataset and data loader
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                           download=True,
                                           transform=train_transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
# validation dataset and data loader
valid_dataset = torchvision.datasets.MNIST(root='./data', train=False,
                                           download=True,
                                           transform=valid_transform)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=batch_size,
                                               shuffle=False)
```
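To verify the loaders (again a hypothetical check, not in the original script), pull one batch and print its shape; with batch_size = 32 you should see 32 single-channel 224x224 images:

```python
# hypothetical check: fetch one batch from the training loader
images, labels = next(iter(train_dataloader))
print(images.shape)   # torch.Size([32, 1, 224, 224])
print(labels.shape)   # torch.Size([32])
print(len(train_dataset), len(valid_dataset))  # 60000 10000
```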
2.5) Model initialization, loss function, and optimizer
```python
# instantiate the model
model = VGG11(in_channels=1, num_classes=10).to(device)
# total parameters and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"[INFO]: {total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"[INFO]: {total_trainable_params:,} trainable parameters.")
# the loss function
criterion = nn.CrossEntropyLoss()
# the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9,
                      weight_decay=0.0005)
```
- The MNIST images are grayscale, so the number of input channels is 1; there are 10 output classes;
- We use the cross-entropy loss function, and SGD performs the parameter updates (see the short demo after this list).
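As a brief aside (hypothetical, not part of train.py): nn.CrossEntropyLoss applies log-softmax to the raw logits and then takes the negative log-likelihood, so the model should output unnormalized scores rather than probabilities:

```python
# hypothetical aside: CrossEntropyLoss == log-softmax + negative log-likelihood
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 10)           # a fake batch: 4 samples, 10 classes
targets = torch.tensor([3, 1, 0, 9])  # fake ground-truth labels

ce = nn.CrossEntropyLoss()(logits, targets)
nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(ce.item(), nll.item())  # the two values are identical
```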
2.6) The training function
```python
# training
def train(model, trainloader, optimizer, criterion):
    model.train()
    print('Training')
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
        counter += 1

        image, labels = data
        image = image.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # forward pass
        outputs = model(image)
        # calculate the loss
        loss = criterion(outputs, labels)
        train_running_loss += loss.item()
        # calculate the accuracy
        _, preds = torch.max(outputs.data, 1)
        train_running_correct += (preds == labels).sum().item()
        loss.backward()
        optimizer.step()

    epoch_loss = train_running_loss / counter
    epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
    return epoch_loss, epoch_acc
```
- Iterate over the training data loader, extracting the images and labels;
- Move the images and labels to the computation device;
- Run the forward pass, compute the loss and the accuracy, then backpropagate the loss and update the parameters;
- Return the loss and accuracy for the current epoch.
2.7) The validation function:
The validation function is a little different. Besides computing the loss and accuracy for each epoch, we also compute the accuracy for each individual class, to assess how well the model performs on every digit as training progresses.
```python
# validation
def validate(model, testloader, criterion):
    model.eval()

    # we need two lists to keep track of class-wise accuracy
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    print('Validation')
    valid_running_loss = 0.0
    valid_running_correct = 0
    counter = 0
    with torch.no_grad():
        for i, data in tqdm(enumerate(testloader), total=len(testloader)):
            counter += 1

            image, labels = data
            image = image.to(device)
            labels = labels.to(device)
            # forward pass
            outputs = model(image)
            # calculate the loss
            loss = criterion(outputs, labels)
            valid_running_loss += loss.item()
            # calculate the accuracy
            _, preds = torch.max(outputs.data, 1)
            valid_running_correct += (preds == labels).sum().item()
            # calculate the accuracy for each class
            correct = (preds == labels).squeeze()
            for j in range(len(preds)):
                label = labels[j]
                class_correct[label] += correct[j].item()
                class_total[label] += 1

    epoch_loss = valid_running_loss / counter
    epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
    # print the accuracy for each class after every epoch
    # the values should increase as the training goes on
    print('\n')
    for i in range(10):
        print(f"Accuracy of digit {i}: {100 * class_correct[i] / class_total[i]}")
    return epoch_loss, epoch_acc
```
2.8) The training loop:
```python
# start the training
# lists to keep track of losses and accuracies
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []
for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")

    train_epoch_loss, train_epoch_acc = train(model, train_dataloader,
                                              optimizer, criterion)
    valid_epoch_loss, valid_epoch_acc = validate(model, valid_dataloader,
                                                 criterion)
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    train_acc.append(train_epoch_acc)
    valid_acc.append(valid_epoch_acc)

    print('\n')
    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")

    print('-'*50)
```
- The training and validation losses and accuracies are stored in train_loss, valid_loss, train_acc, and valid_acc;
- After every epoch we print the training and validation metrics;
The final steps are to save the trained model and to plot the loss and accuracy curves, as sketched below.
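The original post does not show this code, so the following is a plausible sketch rather than the author's exact implementation; it is kept consistent with the checkpoint key model_state_dict loaded in test.py and with the accuracy.jpg and loss.jpg files in the outputs directory:

```python
# save the trained model (the checkpoint format assumed by test.py below)
torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, '../outputs/model.pth')

# plot and save the accuracy curves
plt.figure(figsize=(10, 7))
plt.plot(train_acc, color='green', label='train accuracy')
plt.plot(valid_acc, color='blue', label='validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig('../outputs/accuracy.jpg')

# plot and save the loss curves
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(valid_loss, color='red', label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('../outputs/loss.jpg')
```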
3. Analyzing the training loss and accuracy
Figure 1: Accuracy plot
Figure 2: Loss plot
4. Running inference on new images
```python
import torch
import cv2
import glob
import torchvision.transforms as transforms
import numpy as np
from models import VGG11
```
Load the trained weights and define the transforms:
```python
# inference on the CPU
device = 'cpu'
# initialize the VGG11 model
model = VGG11(in_channels=1, num_classes=10)
# load the model checkpoint
checkpoint = torch.load('../outputs/model.pth')
# load the trained weights
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
# simple image transforms
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5],
                         std=[0.5])
])
```
Read the images and pass them through the model:
```python
# get all the test image paths
image_paths = glob.glob('../input/test_data/*.jpg')
for i, image_path in enumerate(image_paths):
    orig_img = cv2.imread(image_path)
    # convert to grayscale to make the image single channel
    image = cv2.cvtColor(orig_img, cv2.COLOR_BGR2GRAY)
    image = transform(image)
    # add one extra batch dimension
    image = image.unsqueeze(0).to(device)
    # forward pass the image through the model
    outputs = model(image)
    # get the index of the highest score
    # the highest score indicates the label for the Digit MNIST dataset
    label = np.array(outputs.detach()).argmax()
    print(f"{image_path.split('/')[-1].split('.')[0]}: {label}")
    # put the predicted label on the original image
    cv2.putText(orig_img, str(label), (15, 50), cv2.FONT_HERSHEY_SIMPLEX,
                2, (0, 255, 0), 2)
    # show and save the results
    cv2.imshow('Result', orig_img)
    cv2.waitKey(0)
    cv2.imwrite(f"../outputs/result_{i}.jpg", orig_img)
```
Summary:
In this post we trained a VGG11 network on the handwritten Digit MNIST dataset.
We walked through model initialization, training, and monitoring the model's accuracy and loss along the way.
Finally, we tested the trained model on new images.