前言:
self.dagmm.train() #self.dagmm = DaGMM(self.gmm_k) .DaGMM中沒有重寫父類的train方法
Module類的構造函數:
1 def __init__(self): 2 #初始化模塊內部狀態,由nn.Module和ScriptModule共享 3 torch._C._log_api_usage_once("python.nn_module") 4 5 self.training = True 6 self._parameters = OrderedDict() 7 self._buffers = OrderedDict() 8 self._backward_hooks = OrderedDict() 9 self._forward_hooks = OrderedDict() 10 self._forward_pre_hooks = OrderedDict() 11 self._state_dict_hooks = OrderedDict() 12 self._load_state_dict_pre_hooks = OrderedDict() 13 self._modules = OrderedDict()
其中training屬性表示BatchNorm與Dropout層在訓練階段和測試階段中采取的策略不同,通過判斷training值來決定前向傳播策略。
對於一些含有BatchNorm,Dropout等層的模型,在訓練時使用的forward和驗證時使用的forward在計算上不太一樣。在前向訓練的過程中指定當前模型是訓練還是在驗證。
model.train()#使用BatchNormalizetion()和Dropout() model.eval()#不使用BatchNormalization()和Dropout()
這兩個方法的定義源碼如下:
1、train()
def train(self, mode=True): r"""將模塊設置為訓練模式。 這只對某些模塊有任何影響.如受影響,具體模塊在訓練/評估模式下的行為詳見具體模塊文檔,e.g. :class:`Dropout`, :class:`BatchNorm`,etc. Returns: Module: self """ self.training = mode for module in self.children(): module.train(mode) return self #返回自己(我們自己定義的model)
2、eval()
def eval(self): r"""將模塊設為評估模式。 這只對某些模塊有任何影響。如受影響,具體模塊在培訓/評估模式下的行為詳見具體模塊文檔, e.g. :class:`Dropout`, :class:`BatchNorm`,etc. """ return self.train(False)
從源碼中可以看出,train和eval()方法將本層及子層的training屬性同時設為true或false。
具體如下:
net.train() # 將本層及子層的training設定為True net.eval() # 將本層及子層的training設定為False net.training = True # 注意,對module的設置僅僅影響本層,子module不受影響
參考:
參考1:torch.nn.Module中的training屬性詳情,與Module.train()和Module.eval()的關系