Pytorch文本分類(imdb數據集),含DataLoader數據加載,最優模型保存


用pytorch進行文本分類,數據集為keras內置的imdb影評數據(二分類),代碼包含六個部分(詳見代碼)

代碼地址為:pytorch-imdb-classification 歡迎star~

使用環境:

pytorch:1.1.0

cuda:10.0

gpu:RTX2070

 

(1)導入相應的庫、定義常量以及加載imdb數據

 

(2)使用DataLoader加載數據

 

(3)定義LSTM模型用於文本二分類

 

(4)定義訓練函數和測試函數

 

 

(5)開始模型的訓練(並保存最優模型權重),訓練較快,2min左右

 

(6)加載模型權重並測試

 

 

 

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM