前言:
self.dagmm.train() #self.dagmm = DaGMM(self.gmm_k) .DaGMM中没有重写父类的train方法
Module类的构造函数:
1 def __init__(self): 2 #初始化模块内部状态,由nn.Module和ScriptModule共享 3 torch._C._log_api_usage_once("python.nn_module") 4 5 self.training = True 6 self._parameters = OrderedDict() 7 self._buffers = OrderedDict() 8 self._backward_hooks = OrderedDict() 9 self._forward_hooks = OrderedDict() 10 self._forward_pre_hooks = OrderedDict() 11 self._state_dict_hooks = OrderedDict() 12 self._load_state_dict_pre_hooks = OrderedDict() 13 self._modules = OrderedDict()
其中training属性表示BatchNorm与Dropout层在训练阶段和测试阶段中采取的策略不同,通过判断training值来决定前向传播策略。
对于一些含有BatchNorm,Dropout等层的模型,在训练时使用的forward和验证时使用的forward在计算上不太一样。在前向训练的过程中指定当前模型是训练还是在验证。
model.train()#使用BatchNormalizetion()和Dropout() model.eval()#不使用BatchNormalization()和Dropout()
这两个方法的定义源码如下:
1、train()
def train(self, mode=True): r"""将模块设置为训练模式。 这只对某些模块有任何影响.如受影响,具体模块在训练/评估模式下的行为详见具体模块文档,e.g. :class:`Dropout`, :class:`BatchNorm`,etc. Returns: Module: self """ self.training = mode for module in self.children(): module.train(mode) return self #返回自己(我们自己定义的model)
2、eval()
def eval(self): r"""将模块设为评估模式。 这只对某些模块有任何影响。如受影响,具体模块在培训/评估模式下的行为详见具体模块文档, e.g. :class:`Dropout`, :class:`BatchNorm`,etc. """ return self.train(False)
从源码中可以看出,train和eval()方法将本层及子层的training属性同时设为true或false。
具体如下:
net.train() # 将本层及子层的training设定为True net.eval() # 将本层及子层的training设定为False net.training = True # 注意,对module的设置仅仅影响本层,子module不受影响
参考:
参考1:torch.nn.Module中的training属性详情,与Module.train()和Module.eval()的关系