代碼已上傳到github:https://github.com/taishan1994/tensorflow-text-classification
往期精彩:
利用TfidfVectorizer進行中文文本分類(數據集是復旦中文語料)
利用transformer進行中文文本分類(數據集是復旦中文語料)
基於tensorflow的中文文本分類
數據集:復旦中文語料,包含20類
數據集下載地址:https://www.kesci.com/mw/dataset/5d3a9c86cf76a600360edd04/content
數據集下載好之后將其放置在data文件夾下;
修改globalConfig.py中的全局路徑為自己項目的路徑;
處理后的數據和已訓練好保存的模型,在這里可以下載:
鏈接:https://pan.baidu.com/s/1ZHzO5e__-WFYAYFIt2Kmsg 提取碼:vvzy
目錄結構:
|--checkpint:保存模型目錄
|--|--transformer:transformer模型保存位置;
|--config:配置文件;
|--|--fudanConfig.py:包含訓練配置、模型配置、數據集配置;
|--|--globaConfig.py:全局配置文件,主要是全局路徑、全局參數等;
|-- data:數據保存位置;
|--|--|--Fudan:復旦數據;
|--|--|--train:訓練數據;
|--|--|--answer:測試數據;
|--dataset:創建數據集,對數據進行處理的一些操作;
|--images:結果可視化圖片保存位置;
|--models:模型保存文件;
|--process:對原始數據進行處理后的數據;
|--tensorboard:tensorboard可視化文件保存位置,暫時未用到;
|--utils:輔助函數保存位置,包括word2vec訓練詞向量、評價指標計算、結果可視化等;
|--main.py:主運行文件,選擇模型、訓練、測試和預測;
初始配置:
- 詞嵌入維度:200
- 學習率:0.001
- epoch:50
- 詞匯表大小:6000+2(加2是PAD和UNK)
- 文本最大長度:600
- 每多少個step進行驗證:100
- 每多少個step進行存儲模型:100
環境:
- python=>=3.6
- tensorflow==1.15.0
當前支持的模型:
- bilstm
- bilstm+attention
- textcnn
- rcnn
- transformer
說明
數據的輸入格式:
(1)分詞后去除掉停止詞,再對詞語進行詞頻統計,取頻數最高的前6000個詞語作為詞匯表;
(2)像詞匯表中加入PAD和UNK,實際上的詞匯表的詞語總數為6000+2=6002;
(3)當句子長度大於指定的最大長度,進行裁剪,小於最大長度,在句子前面用PAD進行填充;
(4)如果句子中的詞語在詞匯表中沒有出現則用UNK進行代替;
(5)輸入到網絡中的句子實際上是進行分詞后的詞語映射的id,比如:
(6)輸入的標簽是要經過onehot編碼的;
"""
"我喜歡上海",
"我喜歡打羽毛球",
"""
詞匯表:['我','喜歡','打','上海','羽毛球'],對應映射:[2,3,4,5,6],0對應PAD,1對應UNK
得到:
[
[0,2,3,5],
[2,3,4,6],
]
python main.py --model transformer --saver_dir checkpoint/transformer --save_png images/transformer --train --test --predict
參數說明:
- --model:選擇模型,可選[transformer、bilstm、bilstmattn、textcnn、rcnn]
- --saver_dir:模型保存位置,一般是checkpoint+模型名稱
- --save_png:結果可視化保存位置,一般是images+模型名稱
- --train:是否進行訓練,默認為False
- --test:是否進行測試,默認為False
- --predict:是否進行預測,默認為False
結果
以transformer為例:
部分訓練結果:
2020-11-01T10:43:16.955322, step: 1300, loss: 5.089711, acc: 0.8546,precision: 0.3990, recall: 0.4061, f_beta: 0.3977 * Epoch: 83 train: step: 1320, loss: 0.023474, acc: 0.9922, recall: 0.8444, precision: 0.8474, f_beta: 0.8457 Epoch: 84 train: step: 1340, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: 0.7500, f_beta: 0.7500 Epoch: 85 train: step: 1360, loss: 0.000000, acc: 1.0000, recall: 0.5500, precision: 0.5500, f_beta: 0.5500 Epoch: 86 Epoch: 87 train: step: 1380, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: 0.7500, f_beta: 0.7500 Epoch: 88 train: step: 1400, loss: 0.000000, acc: 1.0000, recall: 0.7000, precision: 0.7000, f_beta: 0.7000 開始驗證。。。 2020-11-01T10:44:07.347359, step: 1400, loss: 5.111372, acc: 0.8506,precision: 0.4032, recall: 0.4083, f_beta: 0.3982 * Epoch: 89 train: step: 1420, loss: 0.000000, acc: 1.0000, recall: 0.5500, precision: 0.5500, f_beta: 0.5500 Epoch: 90 train: step: 1440, loss: 0.000000, acc: 1.0000, recall: 0.5500, precision: 0.5500, f_beta: 0.5500 Epoch: 91 Epoch: 92 train: step: 1460, loss: 0.000000, acc: 1.0000, recall: 0.7000, precision: 0.7000, f_beta: 0.7000 Epoch: 93 train: step: 1480, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: 0.7500, f_beta: 0.7500 Epoch: 94 train: step: 1500, loss: 0.000000, acc: 1.0000, recall: 0.6000, precision: 0.6000, f_beta: 0.6000 開始驗證。。。 2020-11-01T10:44:57.645305, step: 1500, loss: 5.206666, acc: 0.8521,precision: 0.4003, recall: 0.4040, f_beta: 0.3957 Epoch: 95 train: step: 1520, loss: 0.000000, acc: 1.0000, recall: 0.6000, precision: 0.6000, f_beta: 0.6000 Epoch: 96 Epoch: 97 train: step: 1540, loss: 0.000000, acc: 1.0000, recall: 0.7500, precision: 0.7500,