pytorch中的權值初始化


torch.nn.Module.apply(fn)

# 遞歸的調用weights_init函數,遍歷nn.Module的submodule作為參數
# 常用來對模型的參數進行初始化
# fn是對參數進行初始化的函數的句柄,fn以nn.Module或者自己定義的nn.Module的子類作為參數
# fn (Module -> None) – function to be applied to each submodule
# Returns:  self
# Return type:  Module

例子:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02) 
        # m.weight.data是卷積核參數, m.bias.data是偏置項參數
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型實例
netG.apply(weights_init) # 遞歸的調用weights_init函數,遍歷netG的submodule作為參數

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM