Tensorflow+RNN實現新聞文本分類
- 加載數據集
數據集cnew文件夾中有4個文件:
1.訓練集文件cnews.train.txt
2.測試集文件cnew.test.txt
3.驗證集文件cnews.val.txt
4.詞匯表文件cnews.vocab.txt
新聞文本共有10個類別,65000個樣本數據,其中訓練集50000條,測試集10000條,驗證集5000條。
輸入:從txt文本中輸入的數據為新聞類別、新聞內容,進行詞和ID的映射后,所有的詞變為詞向量。
輸出: 預測結果y為一個10維數組,數組中值的取值范圍為[0,1],使用tf.argmax(y,1),取出數組中最大值的下標,再用獨熱表示以及模型輸出轉換成數字標簽。
加載訓練集函數:
加載詞文件函數:
2.詞和ID的映射
將詞匯表中的詞映射到對應的ID,再把訓練文本中的每一條新聞中所有詞轉換成對應ID。使用測試集驗證模型預測准確度時,對測試集中的文本做同樣操作。
1) 使用列表推導式得到詞匯及其id對應的列表,並調用dict方法將列表強制轉換為字典。代碼為:word2id_dict = dict([(b, a) for a, b in enumerate(vocabulary_list)]),將1中的詞匯列表數組做了詞和ID的映射,格式為:{'<PAD>': 0, ',': 1, '的': 2, '。': 3,...}。
2) 使用列表推導式和匿名函數定義函數content2idlist,函數作用是將文章中的每個字轉換為id,代碼為:content2idList = lambda content : [word2id_dict[word] for word in content if word in word2id_dict]。
3) 使用列表推導式得到的結果是列表的列表,總列表train_idlist_list中的元素是每篇文章中的字對應的id列表;代碼為:train_idlist_list = [content2idList(content) for content in train_content_list],格式為:[[387, 1197, …, 3], …,[199, 964, …,3, 24]]
3.構建RNN網絡
1) 設置循環神經網絡的超參數
2) 將每個樣本統一長度為seq_length,train_X = kr.preprocessing.sequence.pad_sequences(train_idlist_list, sequence_length)。
3) 調用LabelEncoder對象的fit_transform方法做標簽編碼,代碼為:train_y = labelEncoder.fit_transform(train_label_list),格式為: [0 0 0 ... 9 9 9],將訓練數據的類別標簽轉換成整型。
4) 調用keras.untils庫的to_categorical方法將標簽編碼的結果再做Ont-Hot編碼,將整型標簽轉換成onehot,代碼為:train_Y = kr.utils.to_categorical(train_y, num_classes),格式為:[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] … [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]。
5) 調用tf庫的get_variable方法實例化可以更新的模型參數embedding,矩陣形狀為vocabulary_size*embedding_size,即5000*64。代碼為embedding = tf.get_variable('embedding', [vocabolary_size, embedding_size])。
6) 使用tf.nn庫的embedding_lookup方法將輸入數據做詞嵌入,代碼為:embedding_inputs = tf.nn.embedding_lookup(embedding, X_holder)。X_holder中已經設置了序列長度為150/600。得到新變量embedding_inputs的形狀為batch_size*sequence_length*embedding_size。
7) RNN層的搭建:將上述六步中處理好的數據輸入到LSTM網絡中,調用的是tf.nn.rnn_cell.BasicLSTMCell 函數,代碼為:lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_hidden_units), num_hidden_units為隱藏層神經元的個數,實驗中設置為了128/256。還要設置一個 dropout 參數,避免過擬合,代碼為:lstm_cell = tf.contrib.rnn.DropoutWrapper(cell=lstm_cell, output_keep_prob=0.75)。
8) 將 LSTM cell 和三維的數據輸入到 tf.nn.dynamic_rnn ,目的為展開整個網絡並構建一整個 RNN 模型,代碼為:outputs, state = tf.nn.dynamic_rnn(lstm_cell, embedding_inputs,dtype=tf.float32)。
9) 獲取最后一個細胞的h,即最后一個細胞的短時記憶矩陣,代碼為:last_cell = outputs[:, -1, :]。
10) 添加第一個全連接層:調用tf.layers.dense方法,將結果賦值給變量full_connect1,形狀為batch_size*num_fc1_units,詞向量大小與全連接層神經元一致。
11) 調用tf.contrib.layers.dropout方法,防止過擬合,代碼為:full_connect1_dropout = tf.contrib.layers.dropout
(full_connect1, dropout_keep_probability)。
12) 調用tf.nn.relu方法,即激活函數,增強擬合復雜函數的能力,代碼為:full_connect1_activate = tf.nn.relu(full_connect1_dropout)。
13) 添加第二個全連接層:操作類似於第一個全連接層,但全連接層的神經元個數為10(對應新聞的10種類別),然后使用Softmax函數,將結果轉化成10個類別的概率。
14) 使用交叉熵作為損失函數,調用tf.train.AdamOptimizer方法定義優化器optimizer 學習率設置為了0.001。代碼為:
所以該文本識別模型為:輸入數據-->RNN層(LSTM模型)-->全連接層1-->全連接層2 -->輸出
4.使用訓練數據迭代訓練模型
模型迭代運行5000/10000次,從訓練集中選取batch_size大小,即64個樣本做批量梯度下降;每訓練100次模型,從測試集中隨機選取200個樣本,驗證一下模型的預測能力。
5.在測試集上進行准確率評估
如5中所示每訓練100次模型,在測試集中隨機選取200個樣本,驗證訓練的模型預測能力。
實驗結果記錄:
實驗中使用了門限循環單元(GRU)、長短期記憶神經網絡(LSTM)做對比。
長短期記憶LSTM是一種和GRUs不同的復雜的激活單元,作用與GRUs相似,但是在單元的結構上不一樣,最終記憶產生融合了輸入門與遺忘門的結果,且輸出門是GRUs中沒有顯性存在的門,目的為從隱層狀態分離最終的記憶。
GRU則是LSTM的一個變體, GRU保持了LSTM的效果同時又使結構更加簡單。
BasicLSTMCell實驗結果:
模型方法 |
BasicLSTMCell |
BasicLSTMCell |
BasicLSTMCell |
BasicLSTMCell |
訓練次數 |
5000 |
5000 |
10000 |
10000 |
序列長度 |
150 |
150 |
600 |
600 |
隱藏層神經元 |
128 |
256 |
128 |
256 |
每次測試樣本數 |
200 |
200 |
200 |
200 |
准確度(5000/10000次訓練) |
0.8700 |
0.8950 |
0.9250 |
0.9350 |
准確度(最大) |
0.9150 |
0.9100 |
0.9250 |
0.9800 |
GRUCell實驗結果:
模型方法 |
GRUCell |
GRUCell |
GRUCell |
GRUCell |
訓練次數 |
5000 |
5000 |
10000 |
10000 |
序列長度 |
150 |
150 |
600 |
600 |
隱藏層神經元 |
128 |
256 |
128 |
256 |
每次測試樣本數 |
200 |
200 |
200 |
200 |
准確度(5000/10000次訓練) |
0.9400 |
0.9300 |
0.9500 |
0.9500 |
准確度(最大) |
0.9550 |
0.9300 |
0.9800 |
0.9750 |
實驗對比可知:
1) 在訓練次數、序列長度、隱藏層神經元相同的條件下,BasicLSTMCell的預測准確度,均低於GRUCell,原因為:TensorFlow中的BasicLSTMCell是一種參考或者標准實現,解決實際問題中不是首選;GRU 參數相對少更容易收斂。
2) 增加訓練次數(5000à1000),模型的預測准確度有提升,因為隨着訓練次數的增多,有用的信息在 LSTM 中進行了保存。
3) LSTM單元的數量很大程度上取決於輸入文本的平均長度,更多的單元數量可以幫助模型存儲更多的文本信息,但模型的訓練時間就會增加很多。