def weight_init(m):
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
# 也可以判斷是否為conv2d,使用相應的初始化方式
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# 是否為批歸一化層
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# 2. 初始化網絡結構
model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)
# 3. 將weight_init應用在子模塊上
model.apply(weight_init)
自定義參數初始化方法
原博客:https://blog.csdn.net/dss_dssssd/article/details/83990511
對某一層進行初始化
https://blog.csdn.net/VictoriaW/article/details/72872036
預訓練部分不想加載
