多層全連接網絡實現手寫數字識別(PyTorch)


具體細節見《深度學習之PyTorch》(廖星宇)

#基於深度神經網絡的手寫數字識別的PyTorch實現
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms #提供預處理

# Three-layer fully connected network
class simpleNet(nn.Module):
    """Three linear layers applied back to back.

    No activation function is placed between the layers, so the forward
    pass is just ``layer3(layer2(layer1(x)))``.

    Args:
        in_dim: size of each input feature vector.
        n_hidden_1: output size of the first linear layer.
        n_hidden_2: output size of the second linear layer.
        out_dim: output size of the final linear layer.
    """

    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        # Layers are created in layer1 -> layer2 -> layer3 order.
        self.layer1 = nn.Linear(in_features=in_dim, out_features=n_hidden_1)
        self.layer2 = nn.Linear(in_features=n_hidden_1, out_features=n_hidden_2)
        self.layer3 = nn.Linear(in_features=n_hidden_2, out_features=out_dim)

    def forward(self, x):
        """Return the output of the three linear layers applied to ``x``."""
        hidden = self.layer2(self.layer1(x))
        return self.layer3(hidden)

# Network with activation functions
class Activation_Net(nn.Module):
    """Three-layer fully connected network with ReLU after the first two layers.

    Args:
        in_dim: size of each input feature vector.
        n_hidden_1: output size of the first linear layer.
        n_hidden_2: output size of the second linear layer.
        out_dim: output size of the final linear layer.
    """

    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        # nn.Sequential chains its sub-modules, like stacking building blocks.
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(inplace=True),
        )
        # The output stage has no activation.
        self.layer3 = nn.Sequential(
            nn.Linear(n_hidden_2, out_dim),
        )

    def forward(self, x):
        """Run ``x`` through layer1, layer2 and layer3 in that order."""
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        return x
        
# Batch normalization (applies the normalization step inside each hidden layer)
class Batch_Net(nn.Module):
    """Three-layer fully connected network with batch normalization.

    Each of the first two stages is Linear -> BatchNorm1d -> ReLU; the last
    stage is a single Linear layer.

    Args:
        in_dim: size of each input feature vector.
        n_hidden_1: output size of the first linear layer.
        n_hidden_2: output size of the second linear layer.
        out_dim: output size of the final linear layer.
    """

    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()

        def hidden_stage(n_in, n_out):
            # Linear projection, then batch norm over its outputs, then in-place ReLU.
            return nn.Sequential(
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(inplace=True),
            )

        self.layer1 = hidden_stage(in_dim, n_hidden_1)
        self.layer2 = hidden_stage(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

    def forward(self, x):
        """Run ``x`` through layer1, layer2 and layer3 in that order."""
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        return x

# Hyperparameters and data pipeline
batch_size = 64        # samples per mini-batch handed out by both loaders
learning_rate = 1e-2   # learning rate passed to the optimizer
num_epoches = 20       # number of passes over the training set

# Preprocessing applied to every image: convert it to a tensor, then
# normalize it with mean 0.5 and std 0.5.
data_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# MNIST handwritten digits stored under ./data; only the training split
# requests a download.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=data_tf,
    download=True,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=data_tf,
)

# Training batches are shuffled; test batches keep the dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Build the model, the loss function and the optimizer
# Input is a flattened 28 x 28 image; the output has 10 values.
model = Batch_Net(in_dim=28 * 28, n_hidden_1=300, n_hidden_2=100, out_dim=10)

# Move the model to the GPU when one is available, otherwise keep it as is.
model = model.cuda() if torch.cuda.is_available() else model

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate)

# Train the network
# Resolve the target device once instead of querying CUDA on every batch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for epoch in range(num_epoches):
    model.train()  # put the model in training mode for this epoch
    for img, label in train_loader:  # each step yields one mini-batch (batch_size = 64)
        # Flatten each image into one row: (batch, 1, 28, 28) -> (batch, 784)
        # for MNIST. Tensors carry autograd state themselves, so the
        # deprecated Variable wrapper is not needed on either device.
        img = img.view(img.size(0), -1).to(device)
        label = label.to(device)

        # Forward pass
        out = model(img)
        loss = criterion(out, label)

        # Backward pass: clear old gradients, backpropagate, update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Printed once per mini-batch: one epoch performs many SGD steps.
        print('Epoch:{}, Loss:{:.4f}'.format(epoch + 1, loss.item()))

# Evaluate the network on the test set
model.eval()  # switch the model to evaluation mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

eval_loss = 0
eval_acc = 0

# `Variable(..., volatile=True)` has had no effect since PyTorch 0.4 (it only
# emits a warning), so the original still recorded the autograd graph.
# torch.no_grad() is the supported way to disable gradient tracking.
with torch.no_grad():
    for img, label in test_loader:
        # Flatten each image into one row, as in training.
        img = img.view(img.size(0), -1).to(device)
        label = label.to(device)

        out = model(img)
        loss = criterion(out, label)

        # Weight the batch loss by the number of samples in this batch
        # (at most batch_size = 64) before summing; the final division by
        # len(test_dataset) then yields a per-sample figure.
        eval_loss += loss.item() * label.size(0)

        # The predicted class is the index of the largest output per row.
        _, predict = torch.max(out, 1)
        num_correct = (predict == label).sum()
        eval_acc += num_correct.item()

print('Loss:{:.6f}, Acc:{:.6f}'.format(eval_loss / len(test_dataset), eval_acc / len(test_dataset)))

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM