用pytorch進行文本分類,數據集為keras內置的imdb影評數據(二分類),代碼包含六個部分(詳見代碼)
代碼地址為:pytorch-imdb-classification 歡迎star~
使用環境:
pytorch:1.1.0
cuda:10.0
gpu:RTX2070
(1)導入相應的庫、定義常量以及加載imdb數據

(2)使用DataLoader加載數據

(3)定義LSTM模型用於文本二分類

(4)定義訓練函數和測試函數

(5)開始模型的訓練(並保存最優模型權重),訓練較快,2min左右

(6)加載模型權重並測試

