pytorch lightning使用(簡要介紹)


0. 簡介
pytorch lightning通過提供LightningModule和LightningDataModule,使得在用pytorch編寫網絡模型時,加載數據、分割數據集、訓練、驗證、測試、計算指標的代碼全部都能很好的組織起來,顯得主程序調用時,代碼簡潔可讀性大幅度提升。

1. pytorch lightning的安裝

1 pip install pytorch-lightning
2 conda install pytorch-lightning -c conda-forge

2. 定義一個網絡模型模型:LightningModule

通過繼承LightningModule,並實現幾個關鍵的函數,使得模型在訓練、驗證和測試過程中能進行模塊化調用,具體細節完全被自定義的函數封裝,整體十分簡潔。定義一個LightningModule的基類,可以實現的函數如下:

 1 from pytorch_lightning import LightningModule
 2  
 3 class MyModel(LightningModule):
 4     """
 5     The only required methods in the LightningModule are:
 6     init
 7     training_step
 8     configure_optimizers
 9     """
10     def __init__(self, *args, **kwargs): pass
11     def forward(self, *args, **kwargs): pass
12     def training_step(self, batch, batch_idx, optimizer_idx, hiddens): pass
13     def training_step_end(self, *args, **kwargs): pass # 接受train_step的返回值
14     def training_epoch_end(self, outputs): pass # 接受train_step一整個epoch的返回值的列表
15     def validation_step(self, batch, batch_idx, dataloader_idx): pass # model.eval() and torch.no_grad() are called automatically
16     def validation_step_end(self, *args, **kwargs): pass # 接受validation_step的返回值
17     def validation_epoch_end(self, outputs)
18     def test_step(self, batch, batch_idx, dataloader_idx): pass # model.eval() and torch.no_grad() are called automatically
19     def test_step_end(self, *args, **kwargs): pass  # 接收test_step的返回值
20     def test_epoch_end(self, outputs): pass
21     def configure_optimizers(self, ): pass
22     def any_extra_hook(...): pass  #  指代任意其他的可重載函數

其中,必須實現的函數只有__init__() 、training_step()、configure_optimizers()。

3. 定義一個數據模型:LightningDataModule

通過定義LightningDataModule的子類,數據集分割、加載的代碼將整合在一起,可以實現的方法有:

 1 class MyDataModule(LightningDataModule):
 2     def __init__(self):
 3         super().__init__()
 4     def prepare_data(self):
 5         # download, split, etc...
 6         # only called on 1 GPU/TPU in distributed
 7     def setup(self,stage:str):  # stage: "fit", "test", 【暫時不知道驗證步驟叫什么名字,可以自己打印一下】
 8         # make assignments here (val/train/test split)
 9         # called on every process in DDP
10     def train_dataloader(self):
11         train_split = Dataset(...)
12         return DataLoader(train_split)
13     def val_dataloader(self):
14         val_split = Dataset(...)
15         return DataLoader(val_split)
16     def test_dataloader(self):
17         test_split = Dataset(...)
18         return DataLoader(test_split)

4. 使用pytorch lightning的API開始訓練

 1 def main():
 2     model = MyModule()
 3     data_module = MyDataModule()
 4     trainer = pytorch_lightning.Trainer(...)  # some arugments, 根據需要傳入你的參數
 5     trainer.fit(module, datamodule=data_module)
 6     trainer.test(module, datamodule=data_module, verbose=True)
 7  
 8  
 9 if __name__ == "__main__":
10     main()

具體實現都通過類封裝之后,主函數就顯得簡潔多了。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM