GPU的性能主要分為兩部分:算力和顯存,前者決定了顯卡計算的速度,后者則決定了顯卡可以同時放入多少數據用於計算。在可以使用的顯存數量一定的情況下,每次訓練能夠加載的數據更多(也就是batch size更大),則可以提高訓練效率。另外有時候數據本身也比較大(比如3D圖像、視頻等),顯存較小的情況下可能甚至batch size為1情況都無法實現,因此顯存的大小十分重要。
我們觀察Pytorch默認的浮點數存儲方式用的是torch.float32,小數點后位數更多固然能夠保證數據的精確性,但絕大多數場景其實並不需要那么精確,只保留一半的信息也不會影響結果,也就是使用torch.float16格式。由於數位減了一半,因此被稱為半精度,具體如下圖:
通過上圖很明顯的可以看到,使用半精度能夠減少顯存占用,使得顯卡可以同時加載更多數據進行計算。
半精度訓練的設置
在Pytorch中使用autocast配置半精度訓練,同時需要在下面三處加以設置:
- import autocast
from torch.cuda.amp import autocast
- 模型設置
在模型定義中,使用python的裝飾器方法,用autocast裝飾模型中的forward函數。關於裝飾器的使用,參考下面:
@autocast()
def forward(self,x):
...
return x
- 訓練過程
在訓練過程中,只需要將數據輸入模型及其之后的部分放入"with autocast():"即可:
for x in train_loader:
x = x.cuda()
with autocast():
output = model(x)
...