目錄
- 概述
- 數據集合
- 代碼
- 結果展示
一、概述
在英文文本分類的基礎上,再來看看中文文本分類。這是一個10分類問題,類別包括體育、財經、房產、家居、教育、科技、時尚、時政、游戲、娛樂。
二、數據集合
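數據集為新聞語料,共有四個數據文件,位於 data/cnews 目錄下(代碼中以 ../data/cnews 引用)。四個文件分別是:訓練集 cnews.train.txt、驗證集 cnews.val.txt、測試集 cnews.test.txt,以及詞匯表 cnews.vocab.txt。詞匯表可以不提供:運行時若該文件不存在,會根據訓練集自動生成。數據格式:每行一條樣本,前面為類別,後面為新聞內容,兩者以制表符(Tab)分隔。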
訓練數據地址:鏈接: https://pan.baidu.com/s/1ZHh98RrjQpG5Tm-yq73vBQ 提取碼:2r04
其中訓練集的格式(每行為「類別<Tab>新聞內容」):
cnews.vocab.txt 的格式:每行一個字;第一行為填充符 <PAD>,用於將所有文本補齊到相同長度。
三、代碼
3.1 數據預處理與加載 cnews_loader.py
# coding: utf-8
"""Loading and preprocessing helpers for the cnews corpus.

Every corpus line is ``<label>\t<text>``; texts are processed character by
character (no word segmentation).
"""
import sys
from collections import Counter

import numpy as np

if sys.version_info[0] > 2:
    is_py3 = True
else:
    reload(sys)  # noqa: F821 -- Python 2 only
    sys.setdefaultencoding("utf-8")
    is_py3 = False


def native_word(word, encoding='utf-8'):
    """Encode *word* to bytes under Python 2; return it unchanged under Python 3.

    Useful when a model trained under Python 3 is used from Python 2.
    """
    if not is_py3:
        return word.encode(encoding)
    return word


def native_content(content):
    """Decode UTF-8 bytes to unicode under Python 2; identity under Python 3."""
    if not is_py3:
        return content.decode('utf-8')
    return content


def open_file(filename, mode='r'):
    """Open a text file in a way that works on both Python 2 and Python 3.

    mode: 'r' or 'w' for read or write. Under Python 3 the file is opened as
    UTF-8 and undecodable bytes are ignored.
    """
    if is_py3:
        return open(filename, mode, encoding='utf-8', errors='ignore')
    return open(filename, mode)


def read_file(filename):
    """Read a ``label<TAB>content`` corpus file.

    Returns:
        (contents, labels): ``contents[i]`` is the list of characters of the
        i-th sample and ``labels[i]`` is its label string. Lines without a tab
        separator, lines with empty content and (Python 2 only) lines that are
        not valid UTF-8 are skipped.
    """
    contents, labels = [], []
    with open_file(filename) as f:
        for line in f:
            try:
                # Split on the first tab only, so a tab inside the article body
                # does not cause the whole sample to be discarded.
                label, content = line.strip().split('\t', 1)
                chars = list(native_content(content))
                label = native_content(label)
            except ValueError:  # also covers UnicodeDecodeError
                continue
            # Append both together so contents and labels stay aligned.
            if chars:
                contents.append(chars)
                labels.append(label)
    return contents, labels


def build_vocab(train_dir, vocab_dir, vocab_size=5000):
    """Build a character vocabulary from the training file and write it to vocab_dir.

    The ``vocab_size - 1`` most frequent characters are written one per line,
    preceded by ``<PAD>`` (id 0), which is used to pad texts to equal length.
    An empty training set yields a vocabulary containing only ``<PAD>``.
    """
    data_train, _ = read_file(train_dir)
    counter = Counter()
    for content in data_train:
        counter.update(content)
    words = ['<PAD>'] + [word for word, _count in counter.most_common(vocab_size - 1)]
    with open_file(vocab_dir, mode='w') as fp:
        fp.write('\n'.join(words) + '\n')


def read_vocab(vocab_dir):
    """Load a vocabulary file (one entry per line).

    Returns:
        (words, word_to_id) where word_to_id maps each entry to its line index.
    """
    with open_file(vocab_dir) as fp:
        # Remove only the line terminator: str.strip() would also erase
        # whitespace entries such as ' ', '\t' or the full-width space U+3000,
        # making those characters impossible to look up. Under Python 2 each
        # entry is converted to unicode.
        words = [native_content(line.rstrip('\r\n')) for line in fp]
    word_to_id = dict(zip(words, range(len(words))))
    return words, word_to_id


def read_category():
    """Return the fixed category list and a category -> id mapping.

    The strings must match the label column of the data files exactly.
    """
    categories = ['體育', '財經', '房產', '家居', '教育', '科技', '時尚', '時政', '游戲', '娛樂']
    categories = [native_content(x) for x in categories]
    cat_to_id = dict(zip(categories, range(len(categories))))
    return categories, cat_to_id


def to_words(content, words):
    """Convert a sequence of ids back into text using the ``words`` list."""
    return ''.join(words[x] for x in content)


def process_file(filename, word_to_id, cat_to_id, max_length=600):
    """Convert a corpus file into padded id / one-hot label arrays.

    Characters missing from word_to_id are dropped. Sequences are padded or
    truncated to max_length with Keras ``pad_sequences`` (default front
    padding/truncation); labels are one-hot encoded with ``to_categorical``.

    Raises:
        KeyError: if a label in the file is not present in cat_to_id.
    """
    # Imported here so the other helpers in this module work without TensorFlow.
    import tensorflow.contrib.keras as kr

    contents, labels = read_file(filename)
    data_id = [[word_to_id[x] for x in content if x in word_to_id] for content in contents]
    label_id = [cat_to_id[label] for label in labels]
    x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
    y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # one-hot labels
    return x_pad, y_pad


def batch_iter(x, y, batch_size=64):
    """Yield shuffled ``(x_batch, y_batch)`` mini-batches covering x and y once.

    x and y are indexed with an index array, so they must be numpy arrays of
    equal length. The last batch may be smaller than batch_size; empty input
    yields no batches.
    """
    data_len = len(x)
    num_batch = (data_len + batch_size - 1) // batch_size  # ceiling division
    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]

    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min(start_id + batch_size, data_len)
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
3.2 模型搭建cnn_model.py
定義訓練參數(TCNNConfig)和 TextCNN 模型。
# coding: utf-8
import tensorflow as tf


class TCNNConfig(object):
    """Hyper-parameters of the TextCNN model and its training loop."""
    embedding_dim = 64        # character embedding dimension
    seq_length = 600          # input sequence length
    num_classes = 10          # number of categories
    num_filters = 256         # number of convolution filters
    kernel_size = 5           # convolution kernel width
    vocab_size = 5000         # vocabulary size
    hidden_dim = 128          # units in the fully-connected layer
    dropout_keep_prob = 0.5   # dropout keep probability during training
    learning_rate = 1e-3      # Adam learning rate
    batch_size = 64           # training batch size
    num_epochs = 10           # total number of epochs
    print_per_batch = 100     # print metrics every N batches
    save_per_batch = 10       # write TensorBoard summaries every N batches


class TextCNN(object):
    """Text classification CNN.

    embedding -> conv1d -> global max pooling -> dense + dropout + ReLU -> dense.
    """

    def __init__(self, config):
        self.config = config
        # The three inputs fed at run time.
        self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
        self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        self.cnn()

    def cnn(self):
        """Build the graph; sets logits, y_pred_cls, loss, optim and acc on self."""
        # Embedding lookup, pinned to the CPU.
        with tf.device('/cpu:0'):
            embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
            embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

        with tf.name_scope("cnn"):
            # CNN layer
            conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, self.config.kernel_size, name='conv')
            # Global max pooling over axis 1 (the sequence axis).
            gmp = tf.reduce_max(conv, axis=1, name='gmp')

        with tf.name_scope("score"):
            # Fully-connected layer followed by dropout and ReLU.
            fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)
            # Classifier
            self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
            # softmax is monotonic, so the argmax of the logits is the predicted class.
            self.y_pred_cls = tf.argmax(self.logits, 1)

        with tf.name_scope("optimize"):
            # Cross-entropy loss. The _v2 op replaces the deprecated one; since
            # input_y is a placeholder, no gradient reaches the labels anyway.
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.logits, labels=self.input_y)
            self.loss = tf.reduce_mean(cross_entropy)
            # Optimizer
            self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

        with tf.name_scope("accuracy"):
            # Accuracy
            correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
            self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
3.3 運行代碼run_cnn.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""Train or test the TextCNN classifier on the cnews data set.

Usage: python run_cnn.py [train|test]   (defaults to ``train``)
"""
from __future__ import print_function

import os
import sys
import time
from datetime import timedelta

import numpy as np
import tensorflow as tf
from sklearn import metrics

from cnn_model import TCNNConfig, TextCNN
from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab

base_dir = '../data/cnews'
train_dir = os.path.join(base_dir, 'cnews.train.txt')
test_dir = os.path.join(base_dir, 'cnews.test.txt')
val_dir = os.path.join(base_dir, 'cnews.val.txt')
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
# NOTE(review): predict.py restores from '../checkpoints/textcnn' while this
# script saves under 'checkpoints/textcnn'; make sure both resolve to the same
# directory for your working-directory layout.
save_dir = 'checkpoints/textcnn'
save_path = os.path.join(save_dir, 'best_validation')  # checkpoint of the best validation result


def get_time_dif(start_time):
    """Return the wall-clock time elapsed since start_time, rounded to whole seconds."""
    time_dif = time.time() - start_time
    return timedelta(seconds=int(round(time_dif)))


def feed_data(x_batch, y_batch, keep_prob):
    """Build a feed dict for the module-level ``model``."""
    return {
        model.input_x: x_batch,
        model.input_y: y_batch,
        model.keep_prob: keep_prob,
    }


def evaluate(sess, x_, y_):
    """Return (loss, accuracy) of ``model`` on x_/y_, averaged per sample, with dropout off.

    Raises:
        ValueError: if x_ is empty (the averages would be undefined).
    """
    data_len = len(x_)
    if data_len == 0:
        raise ValueError("evaluate() needs at least one sample")
    total_loss = 0.0
    total_acc = 0.0
    for x_batch, y_batch in batch_iter(x_, y_, 128):
        batch_len = len(x_batch)
        feed_dict = feed_data(x_batch, y_batch, 1.0)
        loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
        # Weight by batch size because the last batch may be smaller.
        total_loss += loss * batch_len
        total_acc += acc * batch_len
    return total_loss / data_len, total_acc / data_len


def train():
    """Train ``model``, checkpointing to save_path whenever validation accuracy improves."""
    print("Configuring TensorBoard and Saver...")
    # Delete the tensorboard directory before retraining, otherwise the new
    # graphs are written on top of the old ones.
    tensorboard_dir = '../tensorboard/textcnn'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    tf.summary.scalar("loss", model.loss)
    tf.summary.scalar("accuracy", model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)

    # Configure the Saver
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    print("Loading training and validation data...")
    start_time = time.time()
    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
    x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
    print("Time usage:", get_time_dif(start_time))

    session = tf.Session()
    try:
        session.run(tf.global_variables_initializer())
        writer.add_graph(session.graph)
        _fit(session, saver, writer, merged_summary, x_train, y_train, x_val, y_val)
    finally:
        # Release the session and flush TensorBoard events even if training fails.
        writer.close()
        session.close()


def _fit(session, saver, writer, merged_summary, x_train, y_train, x_val, y_val):
    """Run the epoch/batch loop with logging, checkpointing and early stopping."""
    print('Training and evaluating...')
    start_time = time.time()
    total_batch = 0             # batches processed so far
    best_acc_val = 0.0          # best validation accuracy so far
    last_improved = 0           # batch index of the last improvement
    require_improvement = 1000  # stop early after this many batches without improvement

    for epoch in range(config.num_epochs):
        print('Epoch:', epoch + 1)
        for x_batch, y_batch in batch_iter(x_train, y_train, config.batch_size):
            feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)
            if total_batch % config.save_per_batch == 0:
                # Write the training scalars to TensorBoard.
                s = session.run(merged_summary, feed_dict=feed_dict)
                writer.add_summary(s, total_batch)
            if total_batch % config.print_per_batch == 0:
                # Measure train/validation performance with dropout disabled.
                # Use a separate feed dict so the optimisation step below still
                # runs with config.dropout_keep_prob.
                eval_feed = feed_data(x_batch, y_batch, 1.0)
                loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=eval_feed)
                loss_val, acc_val = evaluate(session, x_val, y_val)
                if acc_val > best_acc_val:
                    # Keep the best result so far.
                    best_acc_val = acc_val
                    last_improved = total_batch
                    saver.save(sess=session, save_path=save_path)
                    improved_str = '*'
                else:
                    improved_str = ''
                time_dif = get_time_dif(start_time)
                msg = ('Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},'
                       ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}')
                print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

            session.run(model.optim, feed_dict=feed_dict)  # optimisation step
            total_batch += 1

            if total_batch - last_improved > require_improvement:
                # Validation accuracy has not improved for too long: stop early.
                print("No optimization for a long time, auto-stopping...")
                return


def test():
    """Restore the best checkpoint and report test loss/accuracy, per-class metrics and the confusion matrix."""
    print("Loading test data...")
    start_time = time.time()
    x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=save_path)  # load the saved model

        print('Testing...')
        loss_test, acc_test = evaluate(session, x_test, y_test)
        msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
        print(msg.format(loss_test, acc_test))

        batch_size = 128
        data_len = len(x_test)
        num_batch = (data_len + batch_size - 1) // batch_size

        y_test_cls = np.argmax(y_test, 1)
        y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # predictions
        for i in range(num_batch):
            start_id = i * batch_size
            end_id = min(start_id + batch_size, data_len)
            feed_dict = {
                model.input_x: x_test[start_id:end_id],
                model.keep_prob: 1.0,
            }
            y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

    # Evaluation
    print("Precision, Recall and F1-Score...")
    print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

    # Confusion matrix
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
    print(cm)

    print("Time usage:", get_time_dif(start_time))


if __name__ == '__main__':
    # Optional first argument selects the mode; defaults to training.
    option = sys.argv[1] if len(sys.argv) > 1 else 'train'
    if option not in ('train', 'test'):
        sys.exit("usage: python run_cnn.py [train|test]")

    config = TCNNConfig()
    if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it is missing
        build_vocab(train_dir, vocab_dir, config.vocab_size)
    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(vocab_dir)
    config.vocab_size = len(words)
    model = TextCNN(config)
    if option == 'train':
        train()
    else:
        test()
3.4 預測predict.py
# coding: utf-8
from __future__ import print_function

import os

import tensorflow as tf
import tensorflow.contrib.keras as kr

from cnn_model import TCNNConfig, TextCNN
from cnews_loader import read_category, read_vocab

try:
    bool(type(unicode))
except NameError:
    unicode = str

base_dir = '../data/cnews'
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
# NOTE(review): run_cnn.py saves checkpoints under 'checkpoints/textcnn'
# (relative to its working directory); confirm this path points there too.
save_dir = '../checkpoints/textcnn'
save_path = os.path.join(save_dir, 'best_validation')  # checkpoint of the best validation result


class CnnModel:
    """Restore the saved TextCNN checkpoint and classify raw text."""

    def __init__(self):
        self.config = TCNNConfig()
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.config.vocab_size = len(self.words)
        # Build the model in a private graph: TextCNN creates fixed-name
        # variables via tf.get_variable, so building it a second time in the
        # default graph would fail because the variables already exist.
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.model = TextCNN(self.config)
            self.session = tf.Session(graph=self.graph)
            self.session.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess=self.session, save_path=save_path)  # load the saved model

    def predict(self, message):
        """Return the predicted category name for the text *message*.

        Characters not in the vocabulary are ignored.
        """
        # Works whether the model was trained under Python 2 or 3, in either environment.
        content = unicode(message)
        data = [self.word_to_id[x] for x in content if x in self.word_to_id]

        feed_dict = {
            self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
            self.model.keep_prob: 1.0,
        }

        y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)
        return self.categories[y_pred_cls[0]]

    def close(self):
        """Release the TensorFlow session held by this model."""
        self.session.close()


if __name__ == '__main__':
    cnn_model = CnnModel()
    test_demo = ['三星ST550以全新的拍攝方式超越了以往任何一款數碼相機',
                 '熱火vs騎士前瞻:皇帝回鄉二番戰 東部次席唾手可得新浪體育訊北京時間3月30日7:00']
    for i in test_demo:
        print(cnn_model.predict(i))
    cnn_model.close()
四、結果展示
相關代碼可見:https://github.com/yifanhunter/textClassifier_chinese