1、導入庫
import torch import torch.nn as nn
2、搭建卷積神經網絡
class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.conv = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=3) self.linear = nn.Linear(5766,10) self.relu = nn.ReLU(inplace=True) self.maxpooling = nn.MaxPool2d((2,2)) def forward(self, x): x = self.conv(x) #print (x.shape) x = self.maxpooling(x) #print (x.shape) x = self.relu(x) x = x.view(1,-1) #print (x.shape) x = self.linear(x) return x
對於新手來說,可以先熟悉pytorch的格式。網絡定義一般由兩部分組成,
def __init__(self): 用來定義網絡節點參數;
def forward(self, x):
將節點連接成圖。
卷積計算規則,對我們輸入形狀(1,1,64,64),四個維度分別是(batch,channel,height,width)。batch:一次訓練的批次,channel圖像通道(比如RGB,channel = 3)。height,width分別指圖像的高和寬。
new_height= (height - kernel_size + 2×padding)/(stride[0])+1;padding默認為0,意思的在周圍補一圈零; stride默認為1,因此。
new_height = new_width = (64 - 3)/(1) + 1 = 62。
由於輸出通道數為6,所以通過卷積層后維度(1,6,62,62)
經過pooling后,(1,6,31,31)
x.view(1,-1):把x伸縮為(1,?)的維度,即(1,1×6×31×31)=(1,5766)
nn.Linear(5766,10),把(1,5766)映射為(1,10)的維度。這樣整個網絡其實輸入(1,1,64,64),輸出(1,10)
3、添加訓練數據
if __name__ =='__main__': device = torch.device("cuda" if torch.cuda.is_available() else "cpu") net = Net().to(device) optimizer = torch.optim.Adam(net.parameters()) criterion = nn.MSELoss() net.train() epoch = 100 input = torch.randn(1,1,64,64).cuda() output = torch.ones(1,10).cuda() batch = 32
optimizer:是優化器,即所謂的反向傳播算法。 criterion = nn.MSELoss()定義損失函數。
input = torch.randn(1,1,64,64).cuda() output = torch.ones(1,10).cuda()。定義訓練樣本,注意如果在gpu中訓練,在pytorch中需要.cuda()把數據從cpu中導入到gpu中
網絡的功能是給定隨機噪聲向量,輸出是逼近1的單位向量。
4、訓練:
for step in range(epoch): prediction = net(input) loss = criterion(prediction, output) optimizer.zero_grad() #消除優化器梯度 loss.backward() optimizer.step() if step % 10 == 0: print("EPOCHS: {},Loss:{:4f}".format(step, loss))
loss.backward() 指自動求導 optimizer.step() 指根據自動求導反向傳播優化參數。
5、我們可以輸出樣本看看結果:
print (prediction.cpu().detach().numpy())
#返回一個新的 從當前圖中分離的 Variable。 print (output.cpu().numpy())
注意輸出結果時必須對張量.cpu()把張量從gpu轉到cpu中。
對於計算圖中的張量(比如x,prediction),必須加.detach()從計算圖中導出才能轉化成numpy。
輸出結果:
[[1.0127475 0.98897606 1.002695 0.9881151 1.0137383 1.0051517
1.0140573 1.0051212 1.0088345 0.9978328 ]]
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
(若有新的見解請加以批評指正)