pyTorch使用mnist數據集實現手寫數字識別


使用mnist數據集實現手寫數字識別是入門必做吧。這里使用pyTorch框架進行簡單神經網絡的搭建。

首先導入需要的包。

1 import torch
2 import torch.nn as nn
3 import torch.utils.data as Data
4 import torchvision

 

接下來需要下載mnist數據集。我們創建train_data。使用torchvision.datasets.MNIST進行數據集的下載。

1 train_data = torchvision.datasets.MNIST(
2     root='./mnist/',   #下載到該目錄下
3     train=True,                                     #為訓練數據
4     transform=torchvision.transforms.ToTensor(),    #將其裝換為tensor的形式
5     download=True, #第一次設置為true表示下載,下載完成后,將其置成false
6 )

 之后將其導入data_loader中,這個數據加載類會自動幫我們進行數據集的切片。

 1 train_data = torchvision.datasets.MNIST(
 2     root='./mnist',
 3     train=True,
 4     transform=torchvision.transforms.ToTensor(),
 5     download=False
 6 )
 7 train_loader = Data.DataLoader(dataset=train_data, batch_size=32, shuffle=True, num_workers=0)
 8 test_data = torchvision.datasets.MNIST(
 9     root='./mnist',
10     train=False,
11     transform=torchvision.transforms.ToTensor(),
12 )
13 test_loader = Data.DataLoader(dataset=test_data, batch_size=32, shuffle=False, num_workers=0)
14 test_num = len(test_data)

之后開始定義我們的模型,由於minist數據集是灰度圖像,並且圖片的size都是(28, 28, 1),所以輸入圖片的時候不需要進行額外的修改。

 1 class Net(nn.Module):
 2     def __init__(self):
 3         super(Net, self).__init__()
 4         self.conv1 = nn.Sequential(#(1, 28, 28)
 5             nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),#(16, 28, 28)
 6             nn.ReLU(),#(16, 28, 28)
 7             nn.MaxPool2d(kernel_size=2)#(16, 14, 14)
 8         )
 9         self.conv2 = nn.Sequential(
10             nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),#(32, 14, 14)
11             nn.ReLU(),#(32, 14, 14)
12             nn.MaxPool2d(kernel_size=2)#(32, 7, 7)
13         )
14         self.fc = nn.Linear(32 * 7 * 7, 10)
15     def forward(self, x):
16         x = self.conv1(x)
17         x = self.conv2(x)
18         x = x.view(x.size(0), -1)
19         x = self.fc(x)
20         return x

特別注意在最后傳入全連接層時,最好自己將x的size改變以確保不會因為自適應而造成錯誤。因為在傳入全連接層時會默認壓縮成二維,例如[1, 2, 3, 4]會被壓縮成[1*2, 3*4]。

之后開始訓練。

 1 net = Net()
 2 loss_fn = nn.CrossEntropyLoss()
 3 optim = torch.optim.Adam(net.parameters(), lr = 0.001)
 4 
 5 save_path = './mnist.pth'
 6 best_acc = 0.0
 7 for epoch in range(3):
 8 
 9     net.train()
10     running_loss = 0.0
11     for step, data in enumerate(train_loader, start=0):
12         images, labels = data
13         optim.zero_grad()
14         logits = net(images)
15         loss = loss_fn(logits, labels)
16         loss.backward()
17         optim.step()
18 
19 
20         running_loss += loss.item()
21         rate = (step+1)/len(train_loader)
22         a = "*" * int(rate * 50)
23         b = "." * int((1 - rate) * 50)
24         print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
25     print()
26 
27     net.eval()
28     acc = 0.0
29     with torch.no_grad():
30         for data_test in test_loader:
31             test_images, test_labels = data_test
32             outputs = net(test_images)
33             predict_y = torch.max(outputs, dim=1)[1]#torch.max返回兩個數值,一個是最大值,一個是最大值的下標
34             acc += (predict_y == test_labels).sum().item()
35         test_accurate = acc / test_num
36         if test_accurate > best_acc:
37             best_acc = test_accurate
38             torch.save(net.state_dict(), save_path)
39         print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
40               (epoch + 1, running_loss / step, test_accurate))
41 
42 print('Finished Training')

在完成訓練后,訓練的權重會保存在所設置路徑下的文件中,進行預測的時候,建立模型,載入權重,照一張數字的圖片,對其進行裁剪,灰度等操作之后加載入模型進行預測。

 1 from PIL import Image
 2 import  matplotlib.pyplot as plt
 3 from torchvision import transforms
 4 import torch
 5 from model import Net
 6 
 7 img = Image.open("./YLY2@}8UMGLW37S$)NCVZ23.png")
 8 
 9 plt.imshow(img)
10 
11 # [N, C, H, W]
12 
13 train_transform = transforms.Compose([
14         transforms.Grayscale(),
15         transforms.Resize((28, 28)),
16         transforms.ToTensor(),
17 ])
18 
19 img = train_transform(img)
20 # expand batch dimension
21 img = torch.unsqueeze(img, dim=0)
22 
23 # create model
24 model = Net()
25 # load model weights
26 model_weight_path = "./mnist.pth"
27 model.load_state_dict(torch.load(model_weight_path))
28 
29 index_to_class = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
30 
31 
32 model.eval()
33 with torch.no_grad():
34     # predict class
35     y = model(img)
36     #print(y.size())
37     output = torch.squeeze(y)
38     #print(output)
39     predict = torch.softmax(output, dim=0)
40     #print(predict)
41     predict_cla = torch.argmax(predict).numpy()
42     #print(predict_cla)
43 print(index_to_class[predict_cla], predict[predict_cla].numpy())
44 plt.show()

需要注意的是,載入模型的圖片必須多一個維度batch,所以我們用img = torch.unsqueeze(img, dim=0)在圖片的開頭增加一個batch維度。

之后載入圖片,得到輸出,將輸出的batch維度壓縮掉,使用softmax函數得到概率分布,再用argmax函數得到最大值的下標,打印最大值所對應的類別及其概率。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM