pytorch-使用.apply 和 init.normal_()模拟net网络的参数初始化过程


# 构建apply函数体
from torch.nn import init
import torch
class A:
    def __init__(self):
        self.weight = torch.tensor([0.0, 0.0])
        self.bias = 0
        pass
    def apply(self, func):
        func(self)


B = A()


def init_weight(B):
    def init_value(m):
        if hasattr(m, 'weight'):
            init.normal_(m.weight, 0.0, 0.02)

    B.apply(init_value)


init_weight(B)
print(B.weight)

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM