前言:
self.dagmm.train() # self.dagmm = DaGMM(self.gmm_k) enc, dec, z, gamma = self.dagmm(input_data)
在Pytorch中没用调用模型的forward()前向传播,只有实例化把参数传入。
定义一个模型:
1 class Module(nn.Module): 2 def __init__(self): 3 super(Module, self).__init__() 4 # ...... 5 6 def forward(self, x): 7 # ...... 8 return x 9 10 data = ..... #输入数据 11 # 实例化一个对象 12 module = Module() 13 14 # 前向传播 直接把输入传入实列化 15 module(data) #没有使用module.forward(data) 16 #实际上module(data) 等价于module.forward(data)
等价的原因是因为 python class 中的__call__可以让类像函数一样调用,当执行model(x)的时候,底层自动调用forward方法计算结果:
class A(): def __call__(self): print('i can be called like a function') a = A() a() >>>i can be called like a function
在__call__里可以调用其它函数:
class A(object): def __call__(self, param): print('我在__call__中,传入参数', param) res = self.forward(param) return res def forward(self, x): print('我在forward函数中,传入参数类型是值为: ', x) return x a = A() y = a('i') print("传入的参数是:", y) #我在__call__中,传入参数 i 我在forward函数中,传入参数类型是值为: i 传入的参数是: i
附录:
附录1:
可调用的对象:
关于__call__方法,不得不先提一个概念,就是可调用对象(callable),我们平时自定义的函数、内置函数和类都属于可调用对象,但凡是可以把一对括号()应用到某个对象身上都可称之为可调用对象,判断对象是否为可调用对象可以用函数callable。
如果在类中实现了__call__方法,那么实例对象也将成为一个可调用对象。
你也许已经知道,在Python中,方法也是一种高等的对象。这就意味着它们可以被传递到方法中,就像其他对象一样。这是个非常惊人的特征。在Python中,一个特殊的魔术方法可以让类的实例的行为表现得像函数一样,你可以调用他们,将一个函数当做一个参数传到另一个函数中等等。这是一个非常强大的特性,__call__(self, [args...])。
参考:
参考1:Pytorch模型中nn.Model中的forward()前向传播不调用 解释