1.使用apply()
舉例說明:
- Encoder:設計的編碼器模型
- weights_init(): 用來初始化模型
- model.apply():實現初始化
# coding:utf-8
from torch import nn


def weights_init(mod):
    """Initialise a single module in place; meant for ``model.apply(weights_init)``.

    ``apply`` calls this once for every submodule and once for the model
    itself, so ``mod`` can be any module type.  Dispatch is done on the class
    name, using the names as they appear in ``torch.nn``:

    * name contains ``Conv``      -> weight ~ N(0.0, 0.02)
    * name contains ``BatchNorm`` -> weight (gamma) ~ N(1.0, 0.02),
      bias (beta) = 0

    A module that matches by name but has no ``weight``/``bias`` (for example
    ``BatchNorm2d(..., affine=False)``) is left untouched instead of raising.

    Args:
        mod: the module handed in by ``apply``.
    """
    classname = mod.__class__.__name__
    # Trace the type of every module that is passed in.
    print(classname)

    if classname.find('Conv') != -1:
        weight = getattr(mod, 'weight', None)
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        weight = getattr(mod, 'weight', None)
        bias = getattr(mod, 'bias', None)
        if weight is not None:
            # gamma of the batch-norm layer follows N(1, 0.02)
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            # beta of the batch-norm layer starts at 0
            nn.init.constant_(bias, 0)


class Encoder(nn.Module):
    """Convolutional encoder built from stride-2 4x4 convolutions.

    The stack halves the spatial size until the feature map is 4x4 and then
    applies one final 4x4 convolution (stride 1, no padding) that yields a
    1x1 map with ``z_channels`` channels.

    Args:
        input_size: height/width of the input images; must be a multiple
            of 16.
        input_channels: number of channels of the input images.
        base_channnes: channel count after the first convolution; it is
            doubled by every further down-sampling stage.  (The spelling is
            part of the public signature and is kept as is.)
        z_channels: number of output channels.

    Raises:
        AssertionError: if ``input_size`` is not a multiple of 16.
    """

    def __init__(self, input_size, input_channels, base_channnes, z_channels):
        super(Encoder, self).__init__()
        # input_size must be a multiple of 16
        assert input_size % 16 == 0, "input_size has to be a multiple of 16"

        channels = base_channnes
        models = nn.Sequential()
        models.add_module(
            'Conv2_{0}_{1}'.format(input_channels, channels),
            nn.Conv2d(input_channels, channels, 4, 2, 1, bias=False))
        models.add_module(
            'LeakyReLU_{0}'.format(channels),
            nn.LeakyReLU(0.2, inplace=True))

        # The first convolution has already halved the image size.  Floor
        # division mirrors how a stride-2 convolution rounds an odd size
        # down, so the tracked size always equals the real feature-map size
        # (true division would drift, e.g. 9 -> 4.5 while the map is 4, and
        # add one down-sampling stage too many).
        temp_size = input_size // 2

        # Keep down-sampling until the feature map is (at most) 4 wide.
        while temp_size > 4:
            doubled = channels * 2
            models.add_module(
                'Conv2_{0}_{1}'.format(channels, doubled),
                nn.Conv2d(channels, doubled, 4, 2, 1, bias=False))
            models.add_module(
                'BatchNorm2d_{0}'.format(doubled),
                nn.BatchNorm2d(doubled))
            models.add_module(
                'LeakyReLU_{0}'.format(doubled),
                nn.LeakyReLU(0.2, inplace=True))
            channels = doubled
            temp_size //= 2

        # Final layer: collapses a 4x4 feature map to 1x1.
        models.add_module(
            'Conv2_{0}_{1}'.format(channels, z_channels),
            nn.Conv2d(channels, z_channels, 4, 1, 0, bias=False))
        self.models = models

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.models(x)


if __name__ == '__main__':
    e = Encoder(256, 3, 64, 100)
    # weights_init is called once for every module inside e and for e
    # itself; the `mod` argument is each of those modules in turn.
    e.apply(weights_init)

    # Inspect the parameters by name.
    watched = ('models.BatchNorm2d_128.weight', 'models.BatchNorm2d_128.bias')
    for name, param in e.named_parameters():
        print(name)
        # Spot-check one layer: the BatchNorm2d weight should look normally
        # distributed around 1 and the bias should be all zeros.
        if name in watched:
            print(param)
返回:
# 返回的是依次傳入初始化函數的module Conv2d LeakyReLU Conv2d BatchNorm2d LeakyReLU Conv2d BatchNorm2d LeakyReLU Conv2d BatchNorm2d LeakyReLU Conv2d BatchNorm2d LeakyReLU Conv2d BatchNorm2d LeakyReLU Conv2d Sequential Encoder # 輸出name的格式,並根據條件打印出BatchNorm2d-128的兩個參數 models.Conv2_3_64.weight models.Conv2_64_128.weight models.BatchNorm2d_128.weight Parameter containing: tensor([1.0074, 0.9865, 1.0188, 1.0015, 0.9757, 1.0393, 0.9813, 1.0135, 1.0227, 0.9903, 1.0490, 1.0102, 0.9920, 0.9878, 1.0060, 0.9944, 0.9993, 1.0139, 0.9987, 0.9888, 0.9816, 0.9951, 1.0017, 0.9818, 0.9922, 0.9627, 0.9883, 0.9985, 0.9759, 0.9962, 1.0183, 1.0199, 1.0033, 1.0475, 0.9586, 0.9916, 1.0354, 0.9956, 0.9998, 1.0022, 1.0307, 1.0141, 1.0062, 1.0082, 1.0111, 0.9683, 1.0372, 0.9967, 1.0157, 1.0299, 1.0352, 0.9961, 0.9901, 1.0274, 0.9727, 1.0042, 1.0278, 1.0134, 0.9648, 0.9887, 1.0225, 1.0175, 1.0002, 0.9988, 0.9839, 1.0023, 0.9913, 0.9657, 1.0404, 1.0197, 1.0221, 0.9925, 0.9962, 0.9910, 0.9865, 1.0342, 1.0156, 0.9688, 1.0015, 1.0055, 0.9751, 1.0304, 1.0132, 0.9778, 0.9900, 1.0092, 0.9745, 1.0067, 1.0077, 1.0057, 1.0117, 0.9850, 1.0309, 0.9918, 0.9945, 0.9935, 0.9746, 1.0366, 0.9913, 0.9564, 1.0071, 1.0370, 0.9774, 1.0126, 1.0040, 0.9946, 1.0080, 1.0126, 0.9761, 0.9811, 0.9974, 0.9992, 1.0338, 1.0104, 0.9931, 1.0204, 1.0230, 1.0255, 0.9969, 1.0079, 1.0127, 0.9816, 1.0132, 0.9884, 0.9691, 0.9922, 1.0166, 0.9980], requires_grad=True) models.BatchNorm2d_128.bias Parameter containing: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True) 
models.Conv2_128_256.weight models.BatchNorm2d_256.weight models.BatchNorm2d_256.bias models.Conv2_256_512.weight models.BatchNorm2d_512.weight models.BatchNorm2d_512.bias models.Conv2_512_1024.weight models.BatchNorm2d_1024.weight
models.BatchNorm2d_1024.bias models.Conv2_1024_2048.weight models.BatchNorm2d_2048.weight models.BatchNorm2d_2048.bias models.Conv2_2048_100.weight
2.直接在定義網絡時定義
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


class Discriminator(nn.Module):
    """Discriminator made of six fully connected layers.

    The network maps a ``z_dim``-dimensional input through five hidden
    ``Linear(…, 1000)`` + ``LeakyReLU(0.2)`` stages to a 2-unit output.
    Parameters are initialised inside ``__init__`` by ``weight_init()``.

    Args:
        z_dim: size of the last dimension of the input.
    """

    def __init__(self, z_dim):
        super(Discriminator, self).__init__()
        self.z_dim = z_dim
        self.net = nn.Sequential(
            nn.Linear(z_dim, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 2),
        )
        self.weight_init()  # parameter initialisation

    def weight_init(self, mode='normal'):
        """(Re-)initialise every submodule with the chosen scheme.

        Args:
            mode: ``'normal'`` (weights ~ N(0, 0.02)) or ``'kaiming'``
                (Kaiming-normal weights).

        Raises:
            ValueError: if ``mode`` is not one of the supported names.
        """
        initializers = {'kaiming': kaiming_init, 'normal': normal_init}
        if mode not in initializers:
            raise ValueError(
                "unknown weight init mode {0!r}; expected one of {1}".format(
                    mode, sorted(initializers)))
        initializer = initializers[mode]
        # modules() walks the whole tree (this module included), so the
        # traversal does not depend on every direct child being an iterable
        # container; the initialisers ignore module types they do not handle.
        for m in self.modules():
            initializer(m)

    def forward(self, z):
        """Return the network output for ``z`` with size-1 dimensions removed.

        NOTE: ``squeeze()`` drops every size-1 dimension, so a batch holding a
        single sample comes back without its batch dimension.
        """
        return self.net(z).squeeze()


def _fill_if_present(param, value):
    """Set ``param`` to the constant ``value`` unless it is ``None``."""
    if param is not None:
        init.constant_(param, value)


def kaiming_init(m):
    """Kaiming-normal weights for Linear/Conv2d; identity-like BatchNorm.

    Linear and Conv2d layers get Kaiming-normal weights and zero bias;
    BatchNorm1d/2d layers get weight 1 and bias 0.  Other modules, and
    parameters that are ``None``, are left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)
        _fill_if_present(m.bias, 0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        _fill_if_present(m.weight, 1)
        _fill_if_present(m.bias, 0)


def normal_init(m):
    """N(0, 0.02) weights for Linear/Conv2d; identity-like BatchNorm.

    Linear and Conv2d layers get weights drawn from N(0, 0.02) and zero
    bias; BatchNorm1d/2d layers get weight 1 and bias 0.  Other modules, and
    parameters that are ``None``, are left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.normal_(m.weight, 0, 0.02)
        _fill_if_present(m.bias, 0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        _fill_if_present(m.weight, 1)
        _fill_if_present(m.bias, 0)
由於 __init__ 中已經調用了 self.weight_init(),之後只要實例化 Discriminator 即可完成參數初始化。