1、問題描述:
pytorch中,在測試階段進行前向推斷運行時,隨着for循環次數的增加,顯存不斷累加變大,最終導致顯存溢出。
2、解決方法:
使用如下代碼處理輸入數據:
假設X為模型的輸入
X = X.cuda()
input_blobs = Variable(X, volatile=True)
output = model(input_blobs)
注意: 一定要設置 volatile=True 該參數,否則在for循環過程中,顯存會不斷累加。
1、問題描述:
pytorch中,在測試階段進行前向推斷運行時,隨着for循環次數的增加,顯存不斷累加變大,最終導致顯存溢出。
2、解決方法:
使用如下代碼處理輸入數據:
假設X為模型的輸入
X = X.cuda()
input_blobs = Variable(X, volatile=True)
output = model(input_blobs)
注意: 一定要設置 volatile=True 該參數,否則在for循環過程中,顯存會不斷累加。
本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。