問題起因:筆者想把別人的torch的代碼復制到筆者的代碼框架下,從而引起的顯存爆炸問題
該bug在困擾了筆者三天的情況下,和學長一同解決了該bug,故在此記錄這次艱辛的debug之路。
嘗試思路1:檢查是否存在保留loss的情況下是否使用了 item() 取值,經檢查,並沒有
嘗試思路2:按照網上的說法,添加兩行下面的代碼:
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
實測發現並沒有用。
嘗試思路3:及時刪除臨時變量和清空顯存的cache,例如每次訓練一個batch就清除模型的輸入輸出。
del inputs,loss gc.collect() torch.cuda.empty_cache()
這樣確實使得模型能夠多訓練幾個epoch,但依舊沒有解決顯存持續增長的問題,而且由於頻繁使用torch.cuda.empty_cache(),導致模型一個epoch的訓練時長翻了3倍多。
嘗試思路4:重新核對原模型代碼,打印模型中所有parameters和register_buffer的require_grad,終於發現是因為模型中的某個register_buffer在訓練過程中,它的require_grad本應該為False,然而遷移到我代碼上的實際訓練過程中變成了True,而這個buffer的占用數據空間也不大,可能是因為變為True之后,導致在顯存中一直被保留,從而最終導致顯存溢出。再將那個buffer在forward函數里的操作放在torch.no_grad()上下文中,問題解決!


總結:如果訓練時顯存占用持續增加,需要謹慎的檢查forward函數中的操作,尤其是在編寫復雜代碼的時候,更需要更細致的檢查!
