Pytorch的nn.Module類中定義的實例方法apply()


參考文檔Module — PyTorch 1.7.0 documentation

1  @torch.no_grad()
2  def init_weights(m):
3      print(m)
4      if type(m) == nn.Linear:
5          m.weight.fill_(1.0)
6          print(m.weight)
7  net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
8  net.apply(init_weights)

net類及其子類都會調用 init_weights() 方法


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM