model.apply(weights_init_normal)


model.apply(weights_init_normal)方法

應用把方法應用於每一個module,這里意思是進行初始化

  1. def weights_init_normal(m): 
  2. classname = m.__class__.__name__ 
  3. if classname.find("Conv") != -1: 
  4. torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 
  5. elif classname.find("BatchNorm2d") != -1: 
  6. torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 
  7. torch.nn.init.constant_(m.bias.data, 0.0) 

這里的意思是選擇module是conv或者是batchNorm2d的層進行初始化


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM