1. Imports
import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
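As a quick check (my own addition, not part of the original code), we can peek at one minibatch to see the shapes the network will receive:

# Sketch: inspect one minibatch from the Fashion-MNIST iterator.
X, y = next(iter(train_iter))
print(X.shape)   # expected: torch.Size([256, 1, 28, 28])
print(y.shape)   # expected: torch.Size([256])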
2. Initializing the Parameters
# PyTorch does not implicitly reshape the inputs. Therefore, we define a
# flatten layer before the linear layer to adjust the shape of the network input.
# nn.Flatten() turns a tensor of any dimensionality into a 2-D tensor: the 0th
# dimension is kept and all remaining dimensions are flattened into one vector.
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);
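To make the comment about nn.Flatten() concrete, here is a small sketch (my own illustration, using a random tensor rather than real data): a tensor with the Fashion-MNIST minibatch shape is flattened to (batch, 784), which is exactly what nn.Linear(784, 10) expects.

# Sketch: nn.Flatten keeps dim 0 and flattens the remaining dims into one vector.
flatten = nn.Flatten()
X = torch.randn(256, 1, 28, 28)   # fake minibatch with the Fashion-MNIST shape
print(flatten(X).shape)           # torch.Size([256, 784])
print(net(X).shape)               # torch.Size([256, 10]), one logit per class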
3. Implementing Softmax
loss = nn.CrossEntropyLoss()
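nn.CrossEntropyLoss takes the raw logits from the network and fuses log-softmax with the negative log-likelihood internally, which avoids the numerical overflow that an explicit softmax can cause. A small sketch (my own, with fake logits and labels) verifying the equivalence:

# Sketch: CrossEntropyLoss(logits, y) equals NLLLoss(log_softmax(logits), y).
logits = torch.randn(4, 10)                    # fake logits: 4 examples, 10 classes
y = torch.tensor([0, 3, 7, 9])                 # fake labels
manual = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), y)
print(torch.isclose(loss(logits, y), manual))  # tensor(True)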
4. Optimization Algorithm
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
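With the default settings, torch.optim.SGD updates every parameter by w ← w − lr · grad on each call to step(). The full loop is handled by d2l.train_ch3 in the next section; the sketch below (my own illustration) only shows what a single update looks like:

# Sketch of one SGD step on one minibatch.
X, y = next(iter(train_iter))
trainer.zero_grad()      # clear gradients accumulated from previous steps
l = loss(net(X), y)      # forward pass + cross-entropy loss
l.backward()             # backpropagate to fill .grad on the parameters
trainer.step()           # apply w <- w - lr * grad to every parameter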
5. Training
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
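d2l.train_ch3 trains the model for num_epochs epochs while plotting training loss, training accuracy, and test accuracy. To check the final test accuracy without d2l's helpers, a minimal hand-written evaluation could look like this (my own sketch, not part of the original notes):

# Sketch: compute test accuracy manually after training.
def evaluate_accuracy(net, data_iter):
    net.eval()                               # switch off training-specific behaviour
    correct, total = 0, 0
    with torch.no_grad():
        for X, y in data_iter:
            y_hat = net(X).argmax(dim=1)     # predicted class = index of largest logit
            correct += (y_hat == y).sum().item()
            total += y.numel()
    return correct / total

print(evaluate_accuracy(net, test_iter))     # roughly 0.8 for this linear model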