本文主要是記錄下,使用PytorchLightning這個如何進行深度學習的訓練,記錄一下本人平常使用這個框架所需要注意的地方,由於框架的理解深入本文會時不時進行更新(第三部分的常見問題會是不是的更新走的),本文深度參考以下兩個網站pytorch_lightning 全程筆記 、Pytorch Lightning 完全攻略如果大家覺得本文寫得不是很清楚,大家可以進一步看看這兩篇文章。
一、框架使用方案
正如網絡上大家介紹的那樣,PL框架可以讓人專心在模型內部的研究。我們在復雜的項目中,可能會出現多個模型,並且模型多個模型之間存在着許多的聯系,如果在項目中想要更換某些模型model,會導致重寫很多代碼。但是如果采用PL框架,那么這將會是一件比較容易的事情。根據Pytorch Lightning 完全攻略這篇文章的推薦,我建議采用以下的代碼風格:
root-
|-dataModule |-__init__.py |-data_interface.py |-xxxdataset1.py |-xxxdataset2.py |-... |-modelModule |-__init__.py |-model_interface.py |-xxxmodel1.py |-xxxmodel2.py |-... |-train.py
其中把dataModule和modelModule寫成python包,這兩個包的__init__.py分別是:
- from .data_interface import DInterface
- from .model_interface import MInterface
在DInterface和MInterface分別是data_interface.py和model_interface.py中創建的類,他們兩個分別就是
- class DInterface(pl.LightningDataModule): 用於所有數據集的接口,在setup()方法中初始化你准備好的xxxdataset1.py,xxxdataset2.py中定義的torch.utils.data.Dataset類。在train_dataloader,val_dataloader,test_dataloader這幾個方法中載入Dataloader即可。
- class MInterface(pl.LightningModule): 用作模型的接口,在
__init__()
函數中import你准備好的xxxmodel2.py,xxxmodel1.py這些模型。重寫training_step方法,validation_step方法,configure_optimizers方法。
當大家在更改模型的時候只需要在對應的模塊上進行更改即可,最后train.py主要功能就是讀取參數,和調用dataModule和modelModule這兩個包進行實例化DInterface和MInterface,當然一些PL框架的回調函數也需要在train.py里進行定義。
二、框架基本模塊(Module)
2.1 LightningModule
LightningModule必須包含的部分是init()和training_step(self, batch, batch_idx),其中init()主要是進行模型的初始化和定義(不需要定義數據集等)。training_step(...)主要是進行定義每個batch數據的處理函數。
def training_step(self, batch, batch_idx):
x, y, z = batch
out = self.encoder(x)
loss = self.loss(out, x)
return loss
這里的batch就是從dataloader中出來的一個batch的數據,類似於for batch in dataloader :。可以看到這個函數的返回值是一個loss,文檔中有提到這個方法一定要返回一個loss,如果是返回一個字典,那么必須包含“loss”這個鍵。當然也有例外,當我們重寫training_step_end()方法的時候就不用training_step必須返回一個loss了,此時可以返回任意的東西,但是要注意training_step_end()方法就必須返回一個loss,至於training_step_end()主要是在使用多GPU訓練的時候需要重寫該方法,主要是進行損失的匯總。
除了training_step,我們還有validation_step,test_step,其中test_step不會在訓練中調用,而validation_step則是對測試數據進行模型推理,一般在這個步驟里可以用self.log進行記錄某些值,例如:
def validation_step(self, batch, batch_idx):
pre = model(batch)
loss = self.lossfun(...)
# log記錄
self.log('val_loss',loss, on_epoch=True, prog_bar=True, logger=True)
self.log()中常用參數以下:
- prog_bar:如果是True,該值將會顯示在進度條上
- logger:如果是True,將會記錄到logger器中(會顯示在tensorboard上)
2.2 LightningDataModule
這一個類必須包含的部分是setup(self, stage=None)方法,train_dataloader()方法。
- setup(self, stage=None):主要是進行Dataset的實例化,包括但不限於進行數據集的划分,划分成訓練集和測試集,一般來說都是Dataset類
- train_dataloader():很簡單,只需要返回一個DataLoader類即可。
有些時候也需要定義collate_fn函數,對一個傳入DataLoader的Dataset進處理。
三、常見問題
參考網站:
pytorch_lightning 全程筆記 - 知乎 (zhihu.com)
Pytorch Lightning 完全攻略 - 知乎 (zhihu.com)
PyTorch Lightning初步教程(上) - 知乎 (zhihu.com)
“簡約版”Pytorch —— Pytorch-Lightning詳解_@YangZai的博客-CSDN博客
201024-5步PyTorchLightning中設置並訪問tensorboard_專注機器學習之路-CSDN博客
PyTorch Ligntning】快速上手簡明指南_聞韶-CSDN博客
PyTorch Lightning 工具學習 - 知乎 (zhihu.com)