PyTorch LSTM的一個簡單例子:實現MNIST圖片分類


上一篇博客中,我們實現了用LSTM對單詞進行詞性判斷,本篇博客我們將實現用LSTM對MNIST圖片分類。MNIST圖片的大小為28*28,我們將其看成長度為28的序列,序列中的每個數據的維度是28,這樣我們就可以把它變成一個序列數據了。代碼如下,代碼中的模型搭建參考了文末的參考資料[1],其余部分參考了文末的參考資料[2]。

'''
本程序實現用LSTM對MNIST進行圖片分類
'''

import torch
import numpy as np
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt


# Hyper parameter
EPOCH = 1
LR = 0.001    # learning rate
BATCH_SIZE = 50

# Mnist digit dataset
train_data = torchvision.datasets.MNIST(
    root='/Users/wangpeng/Desktop/all/CS/Courses/Deep Learning/mofan_PyTorch/mnist/',    # mnist has been downloaded before, use it directly
    train=True,    # this is training data
    transform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                                    # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
    download=False,
)

# print(train_data.data.size())       # (60000, 28, 28)
# print(train_data.targets.size())    # (60000)
# plot one image
# plt.imshow(train_data.data[0].numpy(), cmap='gray')
# plt.title('{:d}'.format(train_data.targets[0]))
# plt.show()

# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(
    root='/Users/wangpeng/Desktop/all/CS/Courses/Deep Learning/mofan_PyTorch/mnist/',
    train=False,  # this is training data
)
# print(test_data.data.size())       # (10000, 28, 28)
# print(test_data.targets.size())    # (10000)
# pick 2000 samples to speed up testing
test_x = test_data.data.type(torch.FloatTensor)[:2000]/255    # shape (2000, 28, 28), value in range(0,1)
test_y = test_data.targets[:2000]


class LSTMnet(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_layer, n_class):
        super(LSTMnet, self).__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
        self.linear = nn.Linear(hidden_dim, n_class)

    def forward(self, x):                  # x's shape (batch_size, 序列長度, 序列中每個數據的長度)
        out, _ = self.lstm(x)              # out's shape (batch_size, 序列長度, hidden_dim)
        out = out[:, -1, :]                # 中間的序列長度取-1,表示取序列中的最后一個數據,這個數據長度為hidden_dim,
                                           # 得到的out的shape為(batch_size, hidden_dim)
        out = self.linear(out)             # 經過線性層后,out的shape為(batch_size, n_class)
        return out


model = LSTMnet(28, 64, 2, 10)             # 圖片大小28*28,lstm的每個隱藏層64個節點,2層隱藏層
if torch.cuda.is_available():
    model = model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# training and testing
for epoch in range(EPOCH):
    for iteration, (train_x, train_y) in enumerate(train_loader):    # train_x's shape (BATCH_SIZE,1,28,28)
        train_x = train_x.squeeze()        # after squeeze, train_x's shape (BATCH_SIZE,28,28),
                                           # 第一個28是序列長度,第二個28是序列中每個數據的長度。
        output = model(train_x)
        loss = criterion(output, train_y)  # cross entropy loss
        optimizer.zero_grad()              # clear gradients for this training step
        loss.backward()                    # backpropagation, compute gradients
        optimizer.step()                   # apply gradients

        if iteration % 100 == 0:
            test_output = model(test_x)
            predict_y = torch.max(test_output, 1)[1].numpy()
            accuracy = float((predict_y == test_y.numpy()).astype(int).sum()) / float(test_y.size(0))
            print('epoch:{:<2d} | iteration:{:<4d} | loss:{:<6.4f} | accuracy:{:<4.2f}'.format(epoch, iteration, loss, accuracy))


# print 10 predictions from test data
test_out = model(test_x[:10])
pred_y = torch.max(test_out, dim=1)[1].data.numpy()
print('The predict number is:')
print(pred_y)
print('The real number is:')
print(test_y[:10].numpy())

結果如下:

下圖為本文的神經網絡處理單張圖片的過程:

 

參考資料:

[1] 10分鍾快速入門PyTorch (6)

[2] 莫煩PyTorch教程系列:CNN卷積神經網絡


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM