[Text Classification - Chinese] textCNN


Contents

  1. Overview
  2. Dataset
  3. Code
  4. Results

1. Overview

Building on the earlier English text-classification example, this post turns to Chinese text classification: a 10-class problem (sports, technology, games, finance, real estate, home, and so on).

2. Dataset

The dataset consists of news articles in four files under the /data/cnews directory: a training set, a test set, a validation set, and a vocabulary file (the final file, cnews.vocab.txt, is optional, since training can regenerate it). Data format: each line starts with the category label, followed by the content.

Training data download link: https://pan.baidu.com/s/1ZHh98RrjQpG5Tm-yq73vBQ (extraction code: 2r04)

In the training set, the label and the content are separated by a tab. The vocabulary file cnews.vocab.txt has one character per line, with <PAD> prepended as the first entry.
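For illustration only, a hypothetical training line could look like this (not copied from the real files; the label and content are tab-separated):

    體育	某球隊在昨晚的比賽中以3比1獲勝……

and the vocabulary file begins:

    <PAD>
    ,
    的
    ...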

3. Code

3.1 Data loading: cnews_loader.py

    # coding: utf-8
    import sys
    from collections import Counter
    import numpy as np
    import tensorflow.contrib.keras as kr

    if sys.version_info[0] > 2:
        is_py3 = True
    else:
        reload(sys)
        sys.setdefaultencoding("utf-8")
        is_py3 = False

    def native_word(word, encoding='utf-8'):
        """If a model trained under Python 3 is used under Python 2, call this to convert the character encoding."""
        if not is_py3:
            return word.encode(encoding)
        else:
            return word

    def native_content(content):
        if not is_py3:
            return content.decode('utf-8')
        else:
            return content

    def open_file(filename, mode='r'):
        """
        Common file helper that works under both Python 2 and Python 3.
        mode: 'r' or 'w' for read or write
        """
        if is_py3:
            return open(filename, mode, encoding='utf-8', errors='ignore')
        else:
            return open(filename, mode)

    def read_file(filename):
        """Read a data file."""
        contents, labels = [], []
        with open_file(filename) as f:
            for line in f:
                try:
                    label, content = line.strip().split('\t')
                    if content:
                        contents.append(list(native_content(content)))
                        labels.append(native_content(label))
                except ValueError:
                    pass  # skip malformed lines

        return contents, labels

    def build_vocab(train_dir, vocab_dir, vocab_size=5000):
        """Build the vocabulary from the training set and store it."""
        data_train, _ = read_file(train_dir)
        all_data = []
        for content in data_train:
            all_data.extend(content)
        counter = Counter(all_data)
        count_pairs = counter.most_common(vocab_size - 1)
        words, _ = list(zip(*count_pairs))
        # Prepend <PAD> so that all texts can be padded to the same length
        words = ['<PAD>'] + list(words)
        with open_file(vocab_dir, mode='w') as f:
            f.write('\n'.join(words) + '\n')

    def read_vocab(vocab_dir):
        """Read the vocabulary."""
        with open_file(vocab_dir) as fp:
            # Under Python 2, convert every value to unicode
            words = [native_content(_.strip()) for _ in fp.readlines()]
        word_to_id = dict(zip(words, range(len(words))))
        return words, word_to_id

    def read_category():
        """Read the category list (fixed)."""
        categories = ['體育', '財經', '房產', '家居', '教育', '科技', '時尚', '時政', '游戲', '娛樂']
        categories = [native_content(x) for x in categories]
        cat_to_id = dict(zip(categories, range(len(categories))))
        return categories, cat_to_id

    def to_words(content, words):
        """Convert id-encoded content back to text."""
        return ''.join(words[x] for x in content)

    def process_file(filename, word_to_id, cat_to_id, max_length=600):
        """Convert a data file to id representation."""
        contents, labels = read_file(filename)

        data_id, label_id = [], []
        for i in range(len(contents)):
            data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
            label_id.append(cat_to_id[labels[i]])
        # Use keras pad_sequences to pad all texts to a fixed length
        x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
        y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # one-hot labels
        return x_pad, y_pad

    def batch_iter(x, y, batch_size=64):
        """Generate shuffled batches."""
        data_len = len(x)
        num_batch = int((data_len - 1) / batch_size) + 1
        indices = np.random.permutation(np.arange(data_len))
        x_shuffle = x[indices]
        y_shuffle = y[indices]

        for i in range(num_batch):
            start_id = i * batch_size
            end_id = min((i + 1) * batch_size, data_len)
            yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
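
Before wiring these helpers into the training script, they can be smoke-tested on their own. A minimal usage sketch (the paths are assumptions based on the layout in section 2; adjust them to yours):

    # Quick sanity check for cnews_loader.py (paths assumed, see section 2)
    from cnews_loader import build_vocab, read_vocab, read_category, process_file, batch_iter

    train_dir = '../data/cnews/cnews.train.txt'
    vocab_dir = '../data/cnews/cnews.vocab.txt'

    build_vocab(train_dir, vocab_dir, vocab_size=5000)  # only needed once
    words, word_to_id = read_vocab(vocab_dir)
    categories, cat_to_id = read_category()

    x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, max_length=600)
    print(x_train.shape, y_train.shape)  # (num_samples, 600) and (num_samples, 10)

    for x_batch, y_batch in batch_iter(x_train, y_train, batch_size=64):
        pass  # each iteration yields one shuffled (x, y) batch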

 

3.2 Model definition: cnn_model.py

Define the training hyper-parameters (TCNNConfig) and the TextCNN() model.

    # coding: utf-8
    import tensorflow as tf

    class TCNNConfig(object):
        """CNN configuration parameters."""
        embedding_dim = 64  # word-embedding dimension
        seq_length = 600  # sequence length
        num_classes = 10  # number of classes
        num_filters = 256  # number of convolution filters
        kernel_size = 5  # convolution kernel size
        vocab_size = 5000  # vocabulary size
        hidden_dim = 128  # fully-connected layer units
        dropout_keep_prob = 0.5  # dropout keep probability
        learning_rate = 1e-3  # learning rate
        batch_size = 64  # training batch size
        num_epochs = 10  # total number of epochs
        print_per_batch = 100  # report results every this many batches
        save_per_batch = 10  # write to TensorBoard every this many batches

    class TextCNN(object):
        """CNN model for text classification."""
        def __init__(self, config):
            self.config = config
            # Three input placeholders
            self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
            self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
            self.cnn()

        def cnn(self):
            """Build the CNN graph."""
            # Word-embedding lookup
            with tf.device('/cpu:0'):
                embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
                embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

            with tf.name_scope("cnn"):
                # CNN layer
                conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, self.config.kernel_size, name='conv')
                # Global max-pooling layer
                gmp = tf.reduce_max(conv, axis=[1], name='gmp')

            with tf.name_scope("score"):
                # Fully-connected layer, followed by dropout and ReLU
                fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
                fc = tf.contrib.layers.dropout(fc, self.keep_prob)
                fc = tf.nn.relu(fc)
                # Classifier
                self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
                self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # predicted class

            with tf.name_scope("optimize"):
                # Loss: cross-entropy
                cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
                self.loss = tf.reduce_mean(cross_entropy)
                # Optimizer
                self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

            with tf.name_scope("accuracy"):
                # Accuracy
                correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
                self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
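
To trace the tensor shapes: the (batch, 600) input ids are embedded to (batch, 600, 64); the 1-D convolution with 256 filters of width 5 yields (batch, 596, 256); global max-pooling over the time axis reduces this to (batch, 256), which passes through the 128-unit hidden layer into the 10-way softmax. For readers on TensorFlow 2.x, where the tf.layers and tf.contrib APIs above no longer exist, a roughly equivalent tf.keras sketch might look like the following (my assumption of a faithful port, not part of the original code):

    # Hypothetical tf.keras port of the TextCNN above (assumes TensorFlow 2.x)
    import tensorflow as tf

    def build_textcnn(vocab_size=5000, seq_length=600, embedding_dim=64,
                      num_filters=256, kernel_size=5, hidden_dim=128, num_classes=10):
        inputs = tf.keras.Input(shape=(seq_length,), dtype='int32')
        x = tf.keras.layers.Embedding(vocab_size, embedding_dim)(inputs)       # (batch, 600, 64)
        x = tf.keras.layers.Conv1D(num_filters, kernel_size)(x)                # (batch, 596, 256)
        x = tf.keras.layers.GlobalMaxPooling1D()(x)                            # (batch, 256)
        x = tf.keras.layers.Dense(hidden_dim)(x)                               # fc1
        x = tf.keras.layers.Dropout(0.5)(x)
        x = tf.keras.layers.Activation('relu')(x)
        outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)  # fc2 + softmax
        model = tf.keras.Model(inputs, outputs)
        model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
                      loss='categorical_crossentropy', metrics=['accuracy'])
        return model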

3.3 Running the model: run_cnn.py

    #!/usr/bin/python
    # -*- coding: utf-8 -*-
    from __future__ import print_function
    import os
    import sys
    import time
    from datetime import timedelta
    import numpy as np
    import tensorflow as tf
    from sklearn import metrics
    from cnn_model import TCNNConfig, TextCNN
    from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab

    base_dir = '../data/cnews'
    train_dir = os.path.join(base_dir, 'cnews.train.txt')
    test_dir = os.path.join(base_dir, 'cnews.test.txt')
    val_dir = os.path.join(base_dir, 'cnews.val.txt')
    vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
    save_dir = 'checkpoints/textcnn'
    save_path = os.path.join(save_dir, 'best_validation')  # path of the best validation checkpoint

    def get_time_dif(start_time):
        """Return the elapsed time."""
        end_time = time.time()
        time_dif = end_time - start_time
        return timedelta(seconds=int(round(time_dif)))

    def feed_data(x_batch, y_batch, keep_prob):
        feed_dict = {
            model.input_x: x_batch,
            model.input_y: y_batch,
            model.keep_prob: keep_prob
        }
        return feed_dict

    def evaluate(sess, x_, y_):
        """Evaluate accuracy and loss on a given dataset."""
        data_len = len(x_)
        batch_eval = batch_iter(x_, y_, 128)
        total_loss = 0.0
        total_acc = 0.0
        for x_batch, y_batch in batch_eval:
            batch_len = len(x_batch)
            feed_dict = feed_data(x_batch, y_batch, 1.0)
            loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
            total_loss += loss * batch_len
            total_acc += acc * batch_len
        return total_loss / data_len, total_acc / data_len

    def train():
        print("Configuring TensorBoard and Saver...")
        # Configure TensorBoard; delete the tensorboard directory before retraining,
        # otherwise the new graph is overlaid on the old one
        tensorboard_dir = '../tensorboard/textcnn'
        if not os.path.exists(tensorboard_dir):
            os.makedirs(tensorboard_dir)
        tf.summary.scalar("loss", model.loss)
        tf.summary.scalar("accuracy", model.acc)
        merged_summary = tf.summary.merge_all()
        writer = tf.summary.FileWriter(tensorboard_dir)

        # Configure Saver
        saver = tf.train.Saver()
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        print("Loading training and validation data...")
        # Load the training and validation sets
        start_time = time.time()
        x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
        x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
        time_dif = get_time_dif(start_time)
        print("Time usage:", time_dif)

        # Create the session
        session = tf.Session()
        session.run(tf.global_variables_initializer())
        writer.add_graph(session.graph)

        print('Training and evaluating...')
        start_time = time.time()
        total_batch = 0  # total batch counter
        best_acc_val = 0.0  # best validation accuracy so far
        last_improved = 0  # batch index of the last improvement
        require_improvement = 1000  # stop early after 1000 batches without improvement

        flag = False
        for epoch in range(config.num_epochs):
            print('Epoch:', epoch + 1)
            batch_train = batch_iter(x_train, y_train, config.batch_size)
            for x_batch, y_batch in batch_train:
                feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)
                if total_batch % config.save_per_batch == 0:
                    # Periodically write training results to the TensorBoard scalars
                    s = session.run(merged_summary, feed_dict=feed_dict)
                    writer.add_summary(s, total_batch)
                if total_batch % config.print_per_batch == 0:
                    # Periodically report performance on the training and validation sets
                    feed_dict[model.keep_prob] = 1.0
                    loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
                    loss_val, acc_val = evaluate(session, x_val, y_val)
                    if acc_val > best_acc_val:
                        # Save the best result so far
                        best_acc_val = acc_val
                        last_improved = total_batch
                        saver.save(sess=session, save_path=save_path)
                        improved_str = '*'
                    else:
                        improved_str = ''
                    time_dif = get_time_dif(start_time)
                    msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                          ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                    print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

                session.run(model.optim, feed_dict=feed_dict)  # run one optimization step
                total_batch += 1

                if total_batch - last_improved > require_improvement:
                    # Validation accuracy has not improved for a long time: stop early
                    print("No optimization for a long time, auto-stopping...")
                    flag = True
                    break  # leave the inner loop
            if flag:  # same as above, for the outer loop
                break

    def test():
        print("Loading test data...")
        start_time = time.time()
        x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

        session = tf.Session()
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=save_path)  # restore the saved model

        print('Testing...')
        loss_test, acc_test = evaluate(session, x_test, y_test)
        msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
        print(msg.format(loss_test, acc_test))

        batch_size = 128
        data_len = len(x_test)
        num_batch = int((data_len - 1) / batch_size) + 1

        y_test_cls = np.argmax(y_test, 1)
        y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # holds the predictions
        for i in range(num_batch):  # process batch by batch
            start_id = i * batch_size
            end_id = min((i + 1) * batch_size, data_len)
            feed_dict = {
                model.input_x: x_test[start_id:end_id],
                model.keep_prob: 1.0
            }
            y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

        # Evaluation report
        print("Precision, Recall and F1-Score...")
        print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

        # Confusion matrix
        print("Confusion Matrix...")
        cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
        print(cm)

        time_dif = get_time_dif(start_time)
        print("Time usage:", time_dif)

    if __name__ == '__main__':
        config = TCNNConfig()
        if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist (here it already exists)
            build_vocab(train_dir, vocab_dir, config.vocab_size)
        categories, cat_to_id = read_category()
        words, word_to_id = read_vocab(vocab_dir)
        config.vocab_size = len(words)
        model = TextCNN(config)
        option = 'train'  # set to 'test' to evaluate the saved model instead
        if option == 'train':
            train()
        else:
            test()
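
The __main__ block above hard-codes option = 'train'; switch it to 'test' after training to evaluate the saved checkpoint. Alternatively, a small variation (my addition, not part of the original script) reads the mode from the command line:

    # Replace the hard-coded `option` in the __main__ block with a command-line
    # argument, e.g.  python run_cnn.py train   or   python run_cnn.py test
    if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']:
        raise ValueError("usage: python run_cnn.py [train / test]")
    option = sys.argv[1]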

3.4 Prediction: predict.py

    # coding: utf-8
    from __future__ import print_function
    import os
    import tensorflow as tf
    import tensorflow.contrib.keras as kr
    from cnn_model import TCNNConfig, TextCNN
    from cnews_loader import read_category, read_vocab

    try:
        bool(type(unicode))
    except NameError:
        unicode = str

    base_dir = '../data/cnews'
    vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
    save_dir = '../checkpoints/textcnn'
    save_path = os.path.join(save_dir, 'best_validation')  # path of the best validation checkpoint

    class CnnModel:
        def __init__(self):
            self.config = TCNNConfig()
            self.categories, self.cat_to_id = read_category()
            self.words, self.word_to_id = read_vocab(vocab_dir)
            self.config.vocab_size = len(self.words)
            self.model = TextCNN(self.config)
            self.session = tf.Session()
            self.session.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess=self.session, save_path=save_path)  # restore the saved model

        def predict(self, message):
            # A model trained under either Python 2 or Python 3 can be run under either version
            content = unicode(message)
            data = [self.word_to_id[x] for x in content if x in self.word_to_id]

            feed_dict = {
                self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
                self.model.keep_prob: 1.0
            }

            y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)
            return self.categories[y_pred_cls[0]]

    if __name__ == '__main__':
        cnn_model = CnnModel()
        test_demo = ['三星ST550以全新的拍攝方式超越了以往任何一款數碼相機',
                     '熱火vs騎士前瞻:皇帝回鄉二番戰 東部次席唾手可得新浪體育訊北京時間3月30日7:00']
        for i in test_demo:
            print(cnn_model.predict(i))

4. Results

The complete code is available at: https://github.com/yifanhunter/textClassifier_chinese

