pytorch-卷積基本網絡結構-提取網絡參數-初始化網絡參數


基本的卷積神經網絡

from torch import nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        layer1 = nn.Sequential() # 將網絡模型進行添加
        layer1.add_module('conv1', nn.Conv2d(3, 32, 3, 1, padding=1)) # nn.Conv
        layer1.add_module('relu1', nn.ReLU(True))
        layer1.add_module('pool1', nn.MaxPool2d(2, 2))
        self.layer1 = layer1

        layer2 = nn.Sequential()
        layer2.add_module('conv2', nn.Conv2d(32, 64, 3, 1, padding=1))
        layer2.add_module('relu2', nn.ReLU(True))
        layer2.add_module('pool2', nn.MaxPool2d(2, 2))
        self.layer2 = layer2

        layer3 = nn.Sequential()
        layer3.add_module('conv3', nn.Conv2d(64, 128, 3, 1, padding=1))
        layer3.add_module('relu3', nn.ReLU(True))
        layer3.add_module('pool3', nn.MaxPool2d(2, 2))
        self.layer3 = layer3

        layer4 = nn.Sequential()
        layer4.add_module('fc1', nn.Linear(2048, 512))
        layer4.add_module('fc_relu1', nn.ReLU(True))
        layer4.add_module('fc2', nn.Linear(512, 64))
        layer4.add_module('fc_relu2', nn.ReLU(True))
        layer4.add_module('fc3', nn.Linear(64, 10))
        self.layer4 = layer4

    def forward(self, x):
        conv1 = self.layer1(x)
        conv2 = self.layer2(conv1)
        conv3 = self.layer3(conv2)
        fc_input = conv3.view(conv3.size(0), -1)
        fc_out = self.layer4(fc_input)

        return fc_out

model = SimpleCNN()
# print(model) # 打印輸出網絡結構

提取前兩層網絡結構 

new_model = nn.Sequential(*list(model.children())[:2])  # 提取前兩層的網絡結構, 構造nn.Sequential網絡串接, * 表示將里面的內容一個個傳進去

提取所有的卷積層網絡

conv_model = nn.Sequential()
# 提取所有的卷積層操作
for name, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d):
        name = name.replace('.', '_')
        conv_model.add_module(name, layer)
print(conv_model)

打印卷積層的網絡名字

for param in model.named_parameters():
    print(param)

對權重參數進行初始化操作

from torch.nn import init
# 對權重參數進行初始化操作
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        init.normal(m.weight.data)
        init.xavier_normal(m.weight.data)
        init.kaiming_normal(m.weight.data)
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM