pytorch-使用.apply 和 init.normal_()模擬net網絡的參數初始化過程


# 構建apply函數體
from torch.nn import init
import torch
class A:
    def __init__(self):
        self.weight = torch.tensor([0.0, 0.0])
        self.bias = 0
        pass
    def apply(self, func):
        func(self)


B = A()


def init_weight(B):
    def init_value(m):
        if hasattr(m, 'weight'):
            init.normal_(m.weight, 0.0, 0.02)

    B.apply(init_value)


init_weight(B)
print(B.weight)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM