在mini-batch訓練中使用tqdm來創建進度條


tqdm是python中的一個用來供我們創建進度條的庫。在進行深度學習的研究時,我們可以使用這個庫為我們直觀地展示當前的訓練進度,下面來說說如何在mini-batch優化中使用這個庫。

我希望程序能夠在每一個epoch都使用進度條來顯示當前epoch的訓練情況,我使用的代碼如下:

1 from tqdm import tqdm #從tqdm庫中導入tadm類
2 
3 for epoch in range(epochs): #訓練輪次
4     with tqdm(total = batch_num, desc=f'Epoch {epoch+1}/{epochs}', unit='it') as pbar: #創建一個進度條
5         for batch_idx in range(batch_num): #mini-batch訓練
6             ...
7             pbar.set_postfix({'batch_loss:'loss}) #在進度條后顯示當前batch的損失
8             pbar.update(1) #更當前進度,1表示完成了一個batch的訓練

所得到的的進度條如下圖所示:

“from tqdm import tqdm”就是從tqdm庫中導入tqdm類。一開始我寫成了“import tqdm”,導致程序報錯,所以這一點要注意。

"with tqdm(total = batch_num, desc=f'Epoch {epoch+1}/{epochs}', unit='it') as pbar: #創建一個進度條"使用python的with結構來創建一個tqdm對象pbar。如果不使用with結構,就需要在一次epoch訓練的結尾調用tqdm對象的close()函數。這一語句中各個參數的意思是:

total:為一個epoch中batch的總數量,即迭代的總次數;

desc:放在進度條最前的一段描述,在此顯示的是當前的epoch及總共需要多少個epoch;

unit:迭代速度的單位,it是iteration的簡寫,這里指的是以每秒完成多少次迭代作為速度度量並顯示在進度條上,見上圖中的“1.41it/s”。

“pbar.set_postfix({'batch_loss:'loss}) #在進度條后顯示當前batch的損失”,在進度條的最后顯示當前batch的損失,如圖中的“batch_loss=1.66e+5”。

pbar.update(1)用來更新進度,這里的1指的是完成了一個batch的訓練,讓進度條加1。

上圖中的“03:21<01:21”分別顯示訓練當前batch已經花費的時間和還需要消耗的時間。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM