超簡單!pytorch入門教程(三):構造一個小型CNN


torch.nn只接受mini-batch的輸入,也就是說我們輸入的時候是必須是好幾張圖片同時輸入。

例如:nn. Conv2d 允許輸入4維的Tensor:n個樣本 x n個色彩頻道 x 高度 x 寬度

 

#coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class Net(nn.Module):
    #定義Net的初始化函數,這個函數定義了該神經網絡的基本結構
    def __init__(self):
        super(Net, self).__init__() #復制並使用Net的父類的初始化方法,即先運行nn.Module的初始化函數
        self.conv1 = nn.Conv2d(1, 6, 5) # 定義conv1函數的是圖像卷積函數:輸入為圖像(1個頻道,即灰度圖),輸出為 6張特征圖, 卷積核為5x5正方形
        self.conv2 = nn.Conv2d(6, 16, 5)# 定義conv2函數的是圖像卷積函數:輸入為6張特征圖,輸出為16張特征圖, 卷積核為5x5正方形
        self.fc1   = nn.Linear(16*5*5, 120) # 定義fc1(fullconnect)全連接函數1為線性函數:y = Wx + b,並將16*5*5個節點連接到120個節點上。
        self.fc2   = nn.Linear(120, 84)#定義fc2(fullconnect)全連接函數2為線性函數:y = Wx + b,並將120個節點連接到84個節點上。
        self.fc3   = nn.Linear(84, 10)#定義fc3(fullconnect)全連接函數3為線性函數:y = Wx + b,並將84個節點連接到10個節點上。

    #定義該神經網絡的向前傳播函數,該函數必須定義,一旦定義成功,向后傳播函數也會自動生成(autograd)
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) #輸入x經過卷積conv1之后,經過激活函數ReLU(原來這個詞是激活函數的意思),使用2x2的窗口進行最大池化Max pooling,然后更新到x。
        x = F.max_pool2d(F.relu(self.conv2(x)), 2) #輸入x經過卷積conv2之后,經過激活函數ReLU,使用2x2的窗口進行最大池化Max pooling,然后更新到x。
        x = x.view(-1, self.num_flat_features(x)) #view函數將張量x變形成一維的向量形式,總特征數並不改變,為接下來的全連接作准備
        x = F.relu(self.fc1(x)) #輸入x經過全連接1,再經過ReLU激活函數,然后更新x
        x = F.relu(self.fc2(x)) #輸入x經過全連接2,再經過ReLU激活函數,然后更新x
        x = self.fc3(x) #輸入x經過全連接3,然后更新x
        return x

    #使用num_flat_features函數計算張量x的總特征量(把每個數字都看出是一個特征,即特征總量),比如x是4*2*2的張量,那么它的特征總量就是16。
    def num_flat_features(self, x):
        size = x.size()[1:] # 這里為什么要使用[1:],是因為pytorch只接受批輸入,也就是說一次性輸入好幾張圖片,那么輸入數據張量的維度自然上升到了4維。【1:】讓我們把注意力放在后3維上面
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
net

# 以下代碼是為了看一下我們需要訓練的參數的數量
print net
params = list(net.parameters())

k=0
for i in params:
    l =1
    print "該層的結構:"+str(list(i.size()))
    for j in i.size():
        l *= j
    print "參數和:"+str(l)
    k = k+l

print "總參數和:"+ str(k)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM