model.apply(weights_init_normal)


model.apply(weights_init_normal)方法

应用把方法应用于每一个module,这里意思是进行初始化

  1. def weights_init_normal(m): 
  2. classname = m.__class__.__name__ 
  3. if classname.find("Conv") != -1: 
  4. torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 
  5. elif classname.find("BatchNorm2d") != -1: 
  6. torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 
  7. torch.nn.init.constant_(m.bias.data, 0.0) 

这里的意思是选择module是conv或者是batchNorm2d的层进行初始化


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM