torch.nn.Module.train() & torch.nn.Module.eval()


前言:

self.dagmm.train() #self.dagmm = DaGMM(self.gmm_k) .DaGMM中沒有重寫父類的train方法

 Module類的構造函數:

 1     def __init__(self):
 2         #初始化模塊內部狀態,由nn.Module和ScriptModule共享
 3         torch._C._log_api_usage_once("python.nn_module")
 4 
 5         self.training = True
 6         self._parameters = OrderedDict()
 7         self._buffers = OrderedDict()
 8         self._backward_hooks = OrderedDict()
 9         self._forward_hooks = OrderedDict()
10         self._forward_pre_hooks = OrderedDict()
11         self._state_dict_hooks = OrderedDict()
12         self._load_state_dict_pre_hooks = OrderedDict()
13         self._modules = OrderedDict()

  其中training屬性表示BatchNorm與Dropout層在訓練階段和測試階段中采取的策略不同,通過判斷training值來決定前向傳播策略。

 對於一些含有BatchNorm,Dropout等層的模型,在訓練時使用的forward和驗證時使用的forward在計算上不太一樣。在前向訓練的過程中指定當前模型是訓練還是在驗證

model.train()#使用BatchNormalizetion()和Dropout()
model.eval()#不使用BatchNormalization()和Dropout()

 這兩個方法的定義源碼如下:

  1、train()

    def train(self, mode=True):
        r"""將模塊設置為訓練模式。

        這只對某些模塊有任何影響.如受影響,具體模塊在訓練/評估模式下的行為詳見具體模塊文檔,e.g. :class:`Dropout`, :class:`BatchNorm`,etc.

        Returns:
            Module: self
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self #返回自己(我們自己定義的model)

  2、eval()

    def eval(self):
        r"""將模塊設為評估模式。
        這只對某些模塊有任何影響。如受影響,具體模塊在培訓/評估模式下的行為詳見具體模塊文檔, e.g. :class:`Dropout`, :class:`BatchNorm`,etc.
        """
        return self.train(False)

 從源碼中可以看出,train和eval()方法將本層及子層的training屬性同時設為true或false。

 具體如下:

net.train() # 將本層及子層的training設定為True
net.eval() # 將本層及子層的training設定為False
net.training = True # 注意,對module的設置僅僅影響本層,子module不受影響

 

參考:

參考1:torch.nn.Module中的training屬性詳情,與Module.train()和Module.eval()的關系


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM