1 導入實驗所需要的包
import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms import torch.nn.functional as F from torch.utils.data import DataLoader,TensorDataset import numpy as np %matplotlib inline
2 下載MNIST數據集以及讀取數據
# Load MNIST (downloaded on first run) as tensors via ToTensor().
train_dataset = torchvision.datasets.MNIST(
    root='../Datasets/MNIST',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='../Datasets/MNIST',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
print(train_dataset.data.shape)
print(train_dataset.targets.shape)

# Use the first GPU when available; fall back to CPU so the cell still runs
# on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Reshuffle the training set every epoch so mini-batch order varies.
# Evaluation order does not affect the loss, so the test set is kept in a
# fixed, reproducible order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
3 定義模型
class LinearNet(nn.Module):
    """Two-layer fully connected network: flatten -> Linear -> ReLU -> Linear.

    Args:
        num_input: number of features after flattening each input sample.
        num_hidden: width of the hidden layer.
        num_output: number of output logits.
    """

    def __init__(self, num_input, num_hidden, num_output):
        super().__init__()
        # NOTE(review): both linear layers are moved to the module-level
        # `device` global right here, so this class only works after that
        # global is defined. Callers that build the model without `.to(...)`
        # currently depend on this placement.
        self.linear1 = nn.Linear(num_input, num_hidden).to(device)
        self.linear2 = nn.Linear(num_hidden, num_output).to(device)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def forward(self, x):
        # Flatten everything after the batch dimension, then apply the
        # hidden layer with ReLU and the output layer (no final activation;
        # the raw logits go straight to the loss).
        hidden = self.relu(self.linear1(self.flatten(x)))
        return self.linear2(hidden)
4 參數初始化
num_input, num_hidden, num_output = 784, 256, 10

# Build the model exactly once and place it on the configured device. The
# optimizers below must hold the parameters of this same instance (the
# previous version built a second, separate model and discarded the first).
net = LinearNet(num_input, num_hidden, num_output).to(device)
for param in net.state_dict():
    print(param)

loss = nn.CrossEntropyLoss()
num_epochs = 100

# Weights and biases get separate optimizers. Only the weights receive L2
# regularization (SGD's weight_decay); the biases use Adam without decay.
param_w = [net.linear1.weight, net.linear2.weight]
param_b = [net.linear1.bias, net.linear2.bias]
# NOTE: the misspelled names `optimzer_w` / `optimzer_b` are kept because
# the training function refers to them as globals.
optimzer_w = torch.optim.SGD(param_w, lr=0.001, weight_decay=0.01)
optimzer_b = torch.optim.Adam(param_b, lr=0.001)
5 定義訓練函數
def train(net, num_epochs):
    """Train ``net`` and record per-epoch train/test losses.

    Relies on globals defined in earlier cells: ``train_loader``,
    ``test_loader``, ``loss``, ``optimzer_w``, ``optimzer_b`` and ``device``.

    Args:
        net: the model to train; its parameters must be the ones registered
            with ``optimzer_w`` and ``optimzer_b``.
        num_epochs: number of full passes over ``train_loader``.

    Returns:
        ``(train_ls, test_ls)``: lists holding, for each epoch, the mean
        per-batch loss on the training and test loaders respectively.
    """
    train_ls, test_ls = [], []
    for epoch in range(num_epochs):
        net.train()
        total = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            l = loss(net(x), y)
            optimzer_w.zero_grad()
            optimzer_b.zero_grad()
            l.backward()
            optimzer_w.step()
            optimzer_b.step()
            total += l.item()
        # Average over batches so train and test losses are on the same
        # scale even though the loaders yield different numbers of batches.
        train_ls.append(total / max(len(train_loader), 1))

        # Evaluation pass: no parameter updates, so skip building the
        # autograd graph. Each batch loss is counted exactly once (the old
        # `l += l.item()` doubled it).
        net.eval()
        total = 0.0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                total += loss(net(x), y).item()
        test_ls.append(total / max(len(test_loader), 1))

        print('epoch: %d, train loss: %f, test loss: %f' % (epoch + 1, train_ls[-1], test_ls[-1]))
    return train_ls, test_ls
6 開始訓練