Pytorch——Pytorch Lightning框架使用手冊


  本文主要是記錄下,使用PytorchLightning這個如何進行深度學習的訓練,記錄一下本人平常使用這個框架所需要注意的地方,由於框架的理解深入本文會時不時進行更新(第三部分的常見問題會是不是的更新走的),本文深度參考以下兩個網站pytorch_lightning 全程筆記 Pytorch Lightning 完全攻略如果大家覺得本文寫得不是很清楚,大家可以進一步看看這兩篇文章。

一、框架使用方案

  正如網絡上大家介紹的那樣,PL框架可以讓人專心在模型內部的研究。我們在復雜的項目中,可能會出現多個模型,並且模型多個模型之間存在着許多的聯系,如果在項目中想要更換某些模型model,會導致重寫很多代碼。但是如果采用PL框架,那么這將會是一件比較容易的事情。根據Pytorch Lightning 完全攻略這篇文章的推薦,我建議采用以下的代碼風格:

root-
    |-dataModule |-__init__.py |-data_interface.py |-xxxdataset1.py |-xxxdataset2.py |-... |-modelModule |-__init__.py |-model_interface.py |-xxxmodel1.py |-xxxmodel2.py |-... |-train.py

  其中把dataModule和modelModule寫成python包,這兩個包的__init__.py分別是:

  • from .data_interface import DInterface
  • from .model_interface import MInterface

  在DInterface和MInterface分別是data_interface.pymodel_interface.py中創建的類,他們兩個分別就是

  • class DInterface(pl.LightningDataModule): 用於所有數據集的接口,在setup()方法中初始化你准備好的xxxdataset1.py,xxxdataset2.py中定義的torch.utils.data.Dataset類。在train_dataloader,val_dataloader,test_dataloader這幾個方法中載入Dataloader即可。
  • class MInterface(pl.LightningModule): 用作模型的接口,在__init__()函數中import你准備好的xxxmodel2.py,xxxmodel1.py這些模型。重寫training_step方法,validation_step方法,configure_optimizers方法。

  當大家在更改模型的時候只需要在對應的模塊上進行更改即可,最后train.py主要功能就是讀取參數,和調用dataModule和modelModule這兩個包進行實例化DInterface和MInterface,當然一些PL框架的回調函數也需要在train.py里進行定義。

二、框架基本模塊(Module)

2.1 LightningModule

  LightningModule必須包含的部分是init()和training_step(self, batch, batch_idx),其中init()主要是進行模型的初始化和定義(不需要定義數據集等)。training_step(...)主要是進行定義每個batch數據的處理函數。

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

  這里的batch就是從dataloader中出來的一個batch的數據,類似於for batch in dataloader :。可以看到這個函數的返回值是一個loss,文檔中有提到這個方法一定要返回一個loss,如果是返回一個字典,那么必須包含“loss”這個鍵。當然也有例外,當我們重寫training_step_end()方法的時候就不用training_step必須返回一個loss了,此時可以返回任意的東西,但是要注意training_step_end()方法就必須返回一個loss,至於training_step_end()主要是在使用多GPU訓練的時候需要重寫該方法,主要是進行損失的匯總。

  除了training_step,我們還有validation_step,test_step,其中test_step不會在訓練中調用,而validation_step則是對測試數據進行模型推理,一般在這個步驟里可以用self.log進行記錄某些值,例如:

def validation_step(self, batch, batch_idx): 
    pre = model(batch)
    loss = self.lossfun(...)
    # log記錄
    self.log('val_loss',loss, on_epoch=True, prog_bar=True, logger=True)

  self.log()中常用參數以下:

  • prog_bar:如果是True,該值將會顯示在進度條上
  • logger:如果是True,將會記錄到logger器中(會顯示在tensorboard上)

2.2 LightningDataModule

  這一個類必須包含的部分是setup(self, stage=None)方法,train_dataloader()方法。

  • setup(self, stage=None):主要是進行Dataset的實例化,包括但不限於進行數據集的划分,划分成訓練集和測試集,一般來說都是Dataset類
  • train_dataloader():很簡單,只需要返回一個DataLoader類即可。

  有些時候也需要定義collate_fn函數,對一個傳入DataLoader的Dataset進處理。

 三、常見問題

 

 

 

參考網站:

pytorch_lightning 全程筆記 - 知乎 (zhihu.com)

Pytorch Lightning 完全攻略 - 知乎 (zhihu.com)

PyTorch Lightning初步教程(上) - 知乎 (zhihu.com)

 “簡約版”Pytorch —— Pytorch-Lightning詳解_@YangZai的博客-CSDN博客

 201024-5步PyTorchLightning中設置並訪問tensorboard_專注機器學習之路-CSDN博客

PyTorch Ligntning】快速上手簡明指南_聞韶-CSDN博客

PyTorch Lightning 工具學習 - 知乎 (zhihu.com)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM