【文本分類-中文】textRNN


一、概述

在英文文本分類的基礎上,本文再來看中文文本分類:這是一個10分類問題(體育、財經、房產、家居、教育、科技、時尚、時政、游戲、娛樂)。

二、數據集合

數據集為新聞數據,總共有四個數據文件,位於../data/cnews目錄下(如下圖所示):訓練集cnews.train.txt、驗證集cnews.val.txt、測試集cnews.test.txt,以及單詞表cnews.vocab.txt。其中單詞表可以不提供,因為運行時若該文件不存在,會根據訓練集自動生成。數據格式:每行前面為類別,後面為描述內容。

訓練數據地址:鏈接: https://pan.baidu.com/s/1ZHh98RrjQpG5Tm-yq73vBQ 提取碼:2r04

其中訓練集的格式:每行一條樣本,類別與內容之間以Tab(\t)分隔。

cnews.vocab.txt的格式:每個字一行,其中第一行為填充符號<PAD>。

三、代碼

3.1 數據采集cnews_loader.py

    1     # coding: utf-8
    2     import sys
    3     from collections import Counter
    4     import numpy as np
    5     import tensorflow.contrib.keras as kr
    6     
    # Record once whether we are running on Python 3; the helpers below branch on is_py3.
    if sys.version_info[0] > 2:
        is_py3 = True
    else:
        # Python 2 only: reload sys and make UTF-8 the process-wide default encoding.
        reload(sys)
        sys.setdefaultencoding("utf-8")
        is_py3 = False
   13     
   def native_word(word, encoding='utf-8'):
       """Return ``word`` in the running interpreter's native string form.

       On Python 2 the word is encoded with ``encoding``; on Python 3 it is
       returned unchanged. Useful when a model trained under Python 3 is used
       from Python 2.
       """
       if is_py3:
           return word
       return word.encode(encoding)
   20     
   def native_content(content):
       """Return ``content`` as text: UTF-8 decoded on Python 2, unchanged on Python 3."""
       return content if is_py3 else content.decode('utf-8')
   26     
   def open_file(filename, mode='r'):
       """Open ``filename`` in a way that works on both Python 2 and Python 3.

       mode: 'r' or 'w' for read or write.

       On Python 3 the file is opened as UTF-8 text with encoding errors
       ignored; on Python 2 the built-in ``open`` is called with just the
       name and mode.
       """
       if not is_py3:
           return open(filename, mode)
       return open(filename, mode, encoding='utf-8', errors='ignore')
   36     
   def read_file(filename):
       """Read a labelled text file into per-character token lists and labels.

       Each line is expected to look like ``label<TAB>content``. A line is
       skipped when it does not split into exactly two tab-separated fields,
       when its content is empty, or when a field cannot be converted by
       ``native_content`` (a decoding ValueError).

       Returns:
           (contents, labels): ``contents[i]`` is the list of characters of the
           i-th kept line and ``labels[i]`` is its label; both lists always
           have the same length.
       """
       contents, labels = [], []
       with open_file(filename) as f:
           for line in f:
               try:
                   label, content = line.strip().split('\t')
                   # Convert both fields before appending either one, so a
                   # failure can never leave the two lists misaligned.
                   chars = list(native_content(content))
                   label = native_content(label)
               except ValueError:
                   # Wrong number of fields or undecodable text: skip the line.
                   # (Only ValueError is caught, so Ctrl-C and real bugs surface.)
                   continue
               if chars:
                   contents.append(chars)
                   labels.append(label)
       return contents, labels
   50     
   def build_vocab(train_dir, vocab_dir, vocab_size=5000):
       """Build a character vocabulary from the training set and write it to disk.

       Args:
           train_dir: path of the training data file (read with ``read_file``).
           vocab_dir: path of the vocabulary file to write, one entry per line.
           vocab_size: total number of entries including the ``<PAD>`` entry.

       The first line written is ``<PAD>``, followed by the
       ``vocab_size - 1`` most frequent characters of the training set.
       An empty training set produces a vocabulary containing only ``<PAD>``.
       """
       data_train, _ = read_file(train_dir)

       counter = Counter()
       for content in data_train:
           counter.update(content)

       # <PAD> comes first so that it gets id 0; it is used to pad all texts
       # to the same length.
       words = ['<PAD>'] + [word for word, _ in counter.most_common(vocab_size - 1)]

       # Use a context manager so the file is flushed and closed deterministically.
       with open_file(vocab_dir, mode='w') as f:
           f.write('\n'.join(words) + '\n')
   64     
   def read_vocab(vocab_dir):
       """Load the vocabulary file.

       Returns:
           (words, word_to_id): the list of entries in file order (one per
           line, stripped) and a dict mapping each entry to its line index.
       """
       words = []
       with open_file(vocab_dir) as fp:
           for line in fp.readlines():
               # On Python 2 every entry is converted to unicode.
               words.append(native_content(line.strip()))
       word_to_id = {word: idx for idx, word in enumerate(words)}
       return words, word_to_id
   73     
   def read_category():
       """Return the fixed category list and a dict mapping category name to index."""
       # NOTE(review): process_file looks labels up in cat_to_id, so these strings
       # presumably must match the label column of the data files exactly
       # (including Simplified/Traditional script) -- confirm against the data.
       names = ['體育', '財經', '房產', '家居', '教育', '科技', '時尚', '時政', '游戲', '娛樂']
       categories = [native_content(name) for name in names]
       cat_to_id = {name: idx for idx, name in enumerate(categories)}
       return categories, cat_to_id
   80     
   def to_words(content, words):
       """Convert a sequence of ids back into the text it represents.

       ``words`` is indexed by each id in ``content``; the looked-up strings are
       concatenated in order.
       """
       pieces = [words[idx] for idx in content]
       return ''.join(pieces)
   84     
   def process_file(filename, word_to_id, cat_to_id, max_length=600):
       """Convert a labelled text file into padded id sequences and one-hot labels.

       Characters missing from ``word_to_id`` are dropped; every label must be a
       key of ``cat_to_id``.

       Returns:
           (x_pad, y_pad): the id sequences padded by keras ``pad_sequences`` to
           ``max_length``, and the labels converted by ``to_categorical`` with
           ``len(cat_to_id)`` classes.
       """
       contents, labels = read_file(filename)

       data_id, label_id = [], []
       for idx, chars in enumerate(contents):
           data_id.append([word_to_id[ch] for ch in chars if ch in word_to_id])
           label_id.append(cat_to_id[labels[idx]])

       # Pad every text to a fixed length with keras' pad_sequences.
       x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
       # Convert the labels to their one-hot representation.
       y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))

       return x_pad, y_pad
   98     
   def batch_iter(x, y, batch_size=64):
       """Yield shuffled mini-batches of ``(x, y)``.

       Args:
           x, y: arrays of equal length that support integer-array indexing
               (e.g. numpy arrays).
           batch_size: maximum number of samples per batch; the final batch may
               be smaller.

       Yields:
           ``(x_batch, y_batch)`` pairs. The same random permutation is applied
           to both arrays so samples stay aligned with their labels. Nothing is
           yielded for empty input.
       """
       data_len = len(x)

       indices = np.random.permutation(np.arange(data_len))
       x_shuffle = x[indices]
       y_shuffle = y[indices]

       # Stepping by batch_size avoids the float-division batch count, which
       # produced one empty batch for empty input.
       for start_id in range(0, data_len, batch_size):
           end_id = min(start_id + batch_size, data_len)
           yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]

 

3.2 模型搭建rnn_model.py

    1     #!/usr/bin/python
    2     # -*- coding: utf-8 -*-
    3     import tensorflow as tf
    4     
    class TRNNConfig(object):
        """Configuration parameters for the RNN text classifier."""

        # --- model parameters ---
        embedding_dim = 64        # dimension of the embedding vectors
        seq_length = 600          # sequence length
        num_classes = 10          # number of classes
        vocab_size = 5000         # vocabulary size
        num_layers = 2            # number of stacked recurrent layers
        hidden_dim = 128          # units per hidden layer
        rnn = 'gru'               # cell type: 'lstm' or 'gru'

        # --- training parameters ---
        dropout_keep_prob = 0.8   # dropout keep probability
        learning_rate = 1e-3      # learning rate
        batch_size = 128          # samples per training batch
        num_epochs = 10           # total number of epochs
        print_per_batch = 100     # report results every this many batches
        save_per_batch = 10       # write to tensorboard every this many batches
   21     
   class TextRNN(object):
       """Text classification model built on a multi-layer RNN (TensorFlow 1.x graph API).

       Constructing an instance creates the placeholders and then builds the
       whole graph via ``rnn()``, which sets ``logits``, ``y_pred_cls``,
       ``loss``, ``optim`` and ``acc`` on the instance.
       """
       def __init__(self, config):
           """Create the input placeholders from ``config`` and build the graph."""
           self.config = config
           # The three inputs fed at run time:
           #   input_x:   int32, [batch, seq_length]
           #   input_y:   float32, [batch, num_classes]
           #   keep_prob: float32 dropout keep probability
           self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
           self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
           self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
           self.rnn()

       def rnn(self):
           """Build the graph: embedding -> stacked RNN -> dense layers -> loss/optimizer/accuracy."""
           def lstm_cell():   # LSTM cell
               return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True)
           def gru_cell():  # GRU cell
               return tf.contrib.rnn.GRUCell(self.config.hidden_dim)
           def dropout(): # wrap one recurrent cell (chosen by config.rnn) with output dropout
               if (self.config.rnn == 'lstm'):
                   cell = lstm_cell()
               else:
                   cell = gru_cell()
               return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

           # Embedding lookup, pinned to the CPU device.
           with tf.device('/cpu:0'):
               embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
               embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

           with tf.name_scope("rnn"):
               # Multi-layer RNN: num_layers dropout-wrapped cells stacked together.
               cells = [dropout() for _ in range(self.config.num_layers)]
               rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

               _outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32)
               # Use the output of the last time step as the sequence representation.
               # NOTE(review): presumably relies on inputs being padded at the front so the
               # last step holds real content -- confirm against the padding used by the loader.
               last = _outputs[:, -1, :]

           with tf.name_scope("score"):
               # Fully connected layer, followed by dropout and ReLU activation.
               fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
               fc = tf.contrib.layers.dropout(fc, self.keep_prob)
               fc = tf.nn.relu(fc)
               # Classifier layer producing one logit per class.
               self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
               self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # predicted class index

           with tf.name_scope("optimize"):
               # Loss: softmax cross-entropy averaged over the batch.
               cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
               self.loss = tf.reduce_mean(cross_entropy)
               # Optimizer: Adam minimizing the loss.
               self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

           with tf.name_scope("accuracy"):
               # Accuracy: fraction of samples whose predicted class equals the one-hot label's argmax.
               correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
               self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

3.3 運行代碼run_rnn.py

    1     # coding: utf-8
    2     from __future__ import print_function
    3     import os
    4     import sys
    5     import time
    6     from datetime import timedelta
    7     import numpy as np
    8     import tensorflow as tf
    9     from sklearn import metrics
   10     from rnn_model import TRNNConfig, TextRNN
   11     from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab
   12     
   # Locations of the data files, the vocabulary file and the checkpoint.
   base_dir = '../data/cnews'
   train_dir = os.path.join(base_dir, 'cnews.train.txt')
   test_dir = os.path.join(base_dir, 'cnews.test.txt')
   val_dir = os.path.join(base_dir, 'cnews.val.txt')
   vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
   save_dir = '../checkpoints/textrnn'
   save_path = os.path.join(save_dir, 'best_validation')  # where the best validation result is saved
   20     
   def get_time_dif(start_time):
       """Return the time elapsed since ``start_time`` as a timedelta rounded to whole seconds."""
       elapsed = time.time() - start_time
       whole_seconds = int(round(elapsed))
       return timedelta(seconds=whole_seconds)
   26     
   def feed_data(x_batch, y_batch, keep_prob):
       """Build the feed dict that binds one batch to the global model's placeholders."""
       return {
           model.input_x: x_batch,
           model.input_y: y_batch,
           model.keep_prob: keep_prob,
       }
   34     
   def evaluate(sess, x_, y_):
       """Evaluate the global ``model`` on a data set.

       The data is processed in order in batches of 128 with dropout disabled
       (keep_prob 1.0).

       Args:
           sess: the TensorFlow session holding the model variables.
           x_, y_: inputs and one-hot labels of equal length.

       Returns:
           (y_pred_class, loss, acc): predicted class indices for every sample,
           in the same order as ``x_``, and the sample-weighted mean loss and
           accuracy. For empty input an empty prediction array and 0.0 / 0.0
           are returned.
       """
       data_len = len(x_)
       if data_len == 0:
           # Nothing to run; also avoids dividing by zero below.
           return np.zeros(shape=0, dtype=np.int64), 0.0, 0.0

       batch_size = 128
       predictions = []
       total_loss = 0.0
       total_acc = 0.0
       # Iterate sequentially rather than through the shuffling batch_iter, so
       # the collected predictions line up with x_ / y_.
       for start_id in range(0, data_len, batch_size):
           x_batch = x_[start_id:start_id + batch_size]
           y_batch = y_[start_id:start_id + batch_size]
           batch_len = len(x_batch)
           feed_dict = feed_data(x_batch, y_batch, 1.0)
           y_pred_batch, loss, acc = sess.run([model.y_pred_cls, model.loss, model.acc], feed_dict=feed_dict)
           predictions.append(y_pred_batch)
           # Weight by batch size because the last batch may be smaller.
           total_loss += loss * batch_len
           total_acc += acc * batch_len
       y_pred_class = np.concatenate(predictions)
       return y_pred_class, total_loss / data_len, total_acc / data_len
   48     
   def train():
       """Train the global ``model`` with periodic validation, checkpointing and early stopping.

       Uses the module-level ``config``, ``model``, ``word_to_id`` and
       ``cat_to_id``. Training summaries are written to TensorBoard every
       ``config.save_per_batch`` batches; every ``config.print_per_batch``
       batches the model is evaluated on the validation set and saved to
       ``save_path`` when validation accuracy improves. Training stops early
       when validation accuracy has not improved for more than 1000 batches.
       """
       print("Configuring TensorBoard and Saver...")
       # Configure TensorBoard. Delete the tensorboard folder before retraining,
       # otherwise the new curves are drawn over the old ones.
       tensorboard_dir = '../tensorboard/textrnn'
       if not os.path.exists(tensorboard_dir):
           os.makedirs(tensorboard_dir)
       tf.summary.scalar("loss", model.loss)
       tf.summary.scalar("accuracy", model.acc)
       merged_summary = tf.summary.merge_all()
       writer = tf.summary.FileWriter(tensorboard_dir)
       # Configure the Saver.
       saver = tf.train.Saver()
       if not os.path.exists(save_dir):
           os.makedirs(save_dir)

       print("Loading training and validation data...")
       # Load the training and validation sets.
       start_time = time.time()
       x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
       x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
       time_dif = get_time_dif(start_time)
       print("Time usage:", time_dif)

       # Create the session; the finally block below releases it and the writer.
       session = tf.Session()
       try:
           session.run(tf.global_variables_initializer())
           writer.add_graph(session.graph)
           print('Training and evaluating...')
           start_time = time.time()
           total_batch = 0  # number of batches processed so far
           best_acc_val = 0.0  # best validation accuracy so far
           last_improved = 0  # batch index of the last improvement
           require_improvement = 1000  # stop early after this many batches without improvement

           stop_early = False
           for epoch in range(config.num_epochs):
               print('Epoch:', epoch + 1)
               batch_train = batch_iter(x_train, y_train, config.batch_size)
               for x_batch, y_batch in batch_train:
                   feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

                   if total_batch % config.save_per_batch == 0:
                       # Write the training scalars to TensorBoard.
                       s = session.run(merged_summary, feed_dict=feed_dict)
                       writer.add_summary(s, total_batch)

                   if total_batch % config.print_per_batch == 0:
                       # Report performance on the current training batch and on
                       # the validation set, with dropout disabled.
                       feed_dict[model.keep_prob] = 1.0
                       loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
                       _, loss_val, acc_val = evaluate(session, x_val, y_val)

                       if acc_val > best_acc_val:
                           # Save the best result so far.
                           best_acc_val = acc_val
                           last_improved = total_batch
                           saver.save(sess=session, save_path=save_path)
                           improved_str = '*'
                       else:
                           improved_str = ''

                       time_dif = get_time_dif(start_time)
                       msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                             + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                       print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

                   # The reporting branch above sets keep_prob to 1.0 in this same
                   # dict; restore the training value so dropout is applied on
                   # every optimization step.
                   feed_dict[model.keep_prob] = config.dropout_keep_prob
                   session.run(model.optim, feed_dict=feed_dict)  # run one optimization step
                   total_batch += 1

                   if total_batch - last_improved > require_improvement:
                       # Validation accuracy has not improved for a long time: stop early.
                       print("No optimization for a long time, auto-stopping...")
                       stop_early = True
                       break  # leave the batch loop
               if stop_early:  # leave the epoch loop as well
                   break
       finally:
           # Flush pending summaries and release the session's resources.
           writer.close()
           session.close()
  125     
  def test():
      """Evaluate the saved best checkpoint on the test set and print a report.

      Uses the module-level ``config``, ``model``, ``word_to_id``,
      ``cat_to_id`` and ``categories``. Prints test loss and accuracy, the
      per-class precision / recall / F1 report and the confusion matrix.
      """
      print("Loading test data...")
      start_time = time.time()
      x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

      session = tf.Session()
      try:
          session.run(tf.global_variables_initializer())
          saver = tf.train.Saver()
          saver.restore(sess=session, save_path=save_path)  # load the saved model

          print('Testing...')
          # Only the aggregate loss and accuracy are used here; per-sample
          # predictions are recomputed in input order below.
          _, loss_test, acc_test = evaluate(session, x_test, y_test)
          msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
          print(msg.format(loss_test, acc_test))

          batch_size = 128
          data_len = len(x_test)

          y_test_cls = np.argmax(y_test, 1)
          y_pred_cls = np.zeros(shape=data_len, dtype=np.int32)  # holds the predictions
          for start_id in range(0, data_len, batch_size):  # process batch by batch
              end_id = min(start_id + batch_size, data_len)
              feed_dict = {
                  model.input_x: x_test[start_id:end_id],
                  model.keep_prob: 1.0
              }
              y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)
      finally:
          # Release the session even if restoring or predicting fails.
          session.close()

      # Evaluation report.
      print("Precision, Recall and F1-Score...")
      print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))
      # Confusion matrix.
      print("Confusion Matrix...")
      cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
      print(cm)
      time_dif = get_time_dif(start_time)
      print("Time usage:", time_dif)
  164     
  if __name__ == '__main__':
      # Mode comes from the first command-line argument ('train' or 'test');
      # with no argument the script trains, as before.
      option = sys.argv[1] if len(sys.argv) > 1 else 'train'
      if option not in ('train', 'test'):
          raise ValueError("usage: python {} [train|test]".format(sys.argv[0]))

      print('Configuring RNN model...')
      config = TRNNConfig()
      if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
          build_vocab(train_dir, vocab_dir, config.vocab_size)
      categories, cat_to_id = read_category()
      words, word_to_id = read_vocab(vocab_dir)
      # Size the embedding table from the vocabulary actually loaded.
      config.vocab_size = len(words)
      model = TextRNN(config)
      if option == 'train':
          train()
      else:
          test()

3.4 預測predict.py

    1     # coding: utf-8
    2     from __future__ import print_function
    3     import os
    4     import tensorflow as tf
    5     import tensorflow.contrib.keras as kr
    6     from rnn_model import TRNNConfig, TextRNN
    7     from  cnews_loader import read_category, read_vocab
    # Make the name ``unicode`` usable on Python 3 as well: if looking it up
    # raises NameError, alias it to ``str``.
    try:
        bool(type(unicode))
    except NameError:
        unicode = str
   12     
   # Locations of the vocabulary file and of the checkpoint to restore.
   base_dir = '../data/cnews'
   vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
   save_dir = '../checkpoints/textrnn'
   save_path = os.path.join(save_dir, 'best_validation')  # where the best validation result is saved
   17     
   class RnnModel:
       """Wrap a restored TextRNN checkpoint for single-sentence prediction."""

       def __init__(self):
           """Load categories and vocabulary, build the graph and restore ``save_path``."""
           self.config = TRNNConfig()
           self.categories, self.cat_to_id = read_category()
           self.words, self.word_to_id = read_vocab(vocab_dir)
           self.config.vocab_size = len(self.words)
           self.model = TextRNN(self.config)

           self.session = tf.Session()
           self.session.run(tf.global_variables_initializer())
           # Load the saved model weights into the session.
           tf.train.Saver().restore(sess=self.session, save_path=save_path)

       def predict(self, message):
           """Return the predicted category name for ``message``.

           Characters missing from the vocabulary are ignored; the remaining
           ids are padded to ``config.seq_length`` and run with dropout
           disabled.
           """
           # Works for models trained under either Python 2 or Python 3.
           text = unicode(message)
           ids = [self.word_to_id[ch] for ch in text if ch in self.word_to_id]
           padded = kr.preprocessing.sequence.pad_sequences([ids], self.config.seq_length)

           predicted = self.session.run(
               self.model.y_pred_cls,
               feed_dict={self.model.input_x: padded, self.model.keep_prob: 1.0}
           )
           return self.categories[predicted[0]]
   41     
   if __name__ == '__main__':
       # Demo: classify two sample sentences and print the predicted category of each.
       classifier = RnnModel()
       samples = [
           '三星ST550以全新的拍攝方式超越了以往任何一款數碼相機',
           '熱火vs騎士前瞻:皇帝回鄉二番戰 東部次席唾手可得新浪體育訊北京時間3月30日7:00',
       ]
       for sentence in samples:
           print(classifier.predict(sentence))

四、結果展示

訓練時長,接近2小時

   

相關代碼可見:https://github.com/yifanhunter/textClassifier_chinese


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM