torch.nn.Module.train() & torch.nn.Module.eval()


前言:

self.dagmm.train() #self.dagmm = DaGMM(self.gmm_k) .DaGMM中没有重写父类的train方法

 Module类的构造函数:

 1     def __init__(self):
 2         #初始化模块内部状态,由nn.Module和ScriptModule共享
 3         torch._C._log_api_usage_once("python.nn_module")
 4 
 5         self.training = True
 6         self._parameters = OrderedDict()
 7         self._buffers = OrderedDict()
 8         self._backward_hooks = OrderedDict()
 9         self._forward_hooks = OrderedDict()
10         self._forward_pre_hooks = OrderedDict()
11         self._state_dict_hooks = OrderedDict()
12         self._load_state_dict_pre_hooks = OrderedDict()
13         self._modules = OrderedDict()

  其中training属性表示BatchNorm与Dropout层在训练阶段和测试阶段中采取的策略不同,通过判断training值来决定前向传播策略。

 对于一些含有BatchNorm,Dropout等层的模型,在训练时使用的forward和验证时使用的forward在计算上不太一样。在前向训练的过程中指定当前模型是训练还是在验证

model.train()#使用BatchNormalizetion()和Dropout()
model.eval()#不使用BatchNormalization()和Dropout()

 这两个方法的定义源码如下:

  1、train()

    def train(self, mode=True):
        r"""将模块设置为训练模式。

        这只对某些模块有任何影响.如受影响,具体模块在训练/评估模式下的行为详见具体模块文档,e.g. :class:`Dropout`, :class:`BatchNorm`,etc.

        Returns:
            Module: self
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self #返回自己(我们自己定义的model)

  2、eval()

    def eval(self):
        r"""将模块设为评估模式。
        这只对某些模块有任何影响。如受影响,具体模块在培训/评估模式下的行为详见具体模块文档, e.g. :class:`Dropout`, :class:`BatchNorm`,etc.
        """
        return self.train(False)

 从源码中可以看出,train和eval()方法将本层及子层的training属性同时设为true或false。

 具体如下:

net.train() # 将本层及子层的training设定为True
net.eval() # 将本层及子层的training设定为False
net.training = True # 注意,对module的设置仅仅影响本层,子module不受影响

 

参考:

参考1:torch.nn.Module中的training属性详情,与Module.train()和Module.eval()的关系


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM