Deep Learning: Multiclass Classification of News Articles


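Besides binary classification, the problem you run into most often is multiclass classification, for example choosing a tag when publishing a blog post. If each sample is associated with exactly one label, the task is single-label multiclass classification; if each sample can be associated with several labels, it is multilabel multiclass classification. This post walks through a multiclass classification problem on news articles.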
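
To make the distinction concrete, here is a minimal sketch of how the two kinds of targets could be encoded. The tag names and the helper `encode_targets` are hypothetical, made up purely for illustration.

```python
# Hypothetical tag vocabulary, for illustration only
TAGS = ["python", "keras", "nlp", "devops"]


def encode_targets(sample_tags, tags=TAGS):
    """Return a 0/1 vector with a 1 for every tag attached to the sample."""
    return [1 if tag in sample_tags else 0 for tag in tags]


# Single-label: exactly one tag per sample, so the vector is one-hot
single_label = encode_targets({"keras"})

# Multilabel: several tags per sample, so the vector is multi-hot
multi_label = encode_targets({"python", "nlp"})

print(single_label)
print(multi_label)
```

The news task in this post is the single-label case: every article belongs to exactly one topic.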

1. Dataset

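This post uses the dataset published by Reuters in 1986. It contains many short newswires together with their topics, 46 topics in total, and is a simple, widely used text classification dataset.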

    def load_data(self):
        # Keep only the num_words most frequent words
        return reuters.load_data(num_words=self.num_words)


    # Inside train():
    (train_data, train_labels), (test_data, test_labels) = self.load_data()
    print(len(train_data))
    print(len(test_data))
    print(train_data[0])
    print(train_labels[0])

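The output shows 8,982 training samples and 2,246 test samples. Both the content and the label of the first training sample are integers: each word is represented by its index, and the label is a topic index.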

8982
2246
[1, 2, 2, 8, 43, 10, 447, 5, 25, 207, 270, 5, 3095, 111, 16, 369, 186, 90, 67, 7, 89, 5, 19, 102, 6, 19, 124, 15, 90, 67, 84, 22, 482, 26, 7, 48, 4, 49, 8, 864, 39, 209, 154, 6, 151, 6, 83, 11, 15, 22, 155, 11, 15, 7, 48, 9, 4579, 1005, 504, 6, 258, 6, 272, 11, 15, 22, 134, 44, 11, 15, 16, 8, 197, 1245, 90, 67, 52, 29, 209, 30, 32, 132, 6, 109, 15, 17, 12]
3

Let's look at the actual text of the first training sample.

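    def get_text(self, data):
        word_id_index = reuters.get_word_index()
        # Invert the mapping: word -> index becomes index -> word
        id_word_index = {index: word for word, index in word_id_index.items()}
        # Indices 0, 1 and 2 are reserved (padding, start of sequence, unknown),
        # so subtract 3 before looking a word up
        return ' '.join(id_word_index.get(i - 3, '?') for i in data)


    print(self.get_text(train_data[0]))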

The decoded sample:

? ? ? said as a result of its december acquisition of space co it expects earnings per share in 1987 of 1 15 to 1 30 dlrs per share up from 70 cts in 1986 the company said pretax net should rise to nine to 10 mln dlrs from six mln dlrs in 1986 and rental operation revenues to 19 to 22 mln dlrs from 12 5 mln dlrs it said cash flow per share this year should be 2 50 to three dlrs reuter 3
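
The three leading `?` characters come from the `i - 3` offset in `get_text`: the Keras loader reserves indices 0, 1 and 2 for padding, start of sequence and unknown words, so real words start three positions later. The sketch below reproduces that behaviour on a hypothetical three-word vocabulary; it is not the real Reuters word index, and `decode` is an illustrative helper.

```python
# Hypothetical word index in the style of reuters.get_word_index():
# it maps each word to its frequency rank, starting at 1.
word_index = {"the": 1, "of": 2, "said": 5}

# Invert it so that ranks map back to words
rank_to_word = {rank: word for word, rank in word_index.items()}


def decode(encoded, offset=3):
    """Decode a list of integers; reserved or unknown indices become '?'."""
    return " ".join(rank_to_word.get(i - offset, "?") for i in encoded)


# 1 marks the start of a sequence and 2 an out-of-vocabulary word,
# so both decode to '?'; 8 - 3 = 5 is the rank of "said".
decoded = decode([1, 2, 2, 8, 4, 5])
print(decoded)
```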

2. Preparing the data

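Encode the training data as one-hot style (multi-hot) vectors: each sample becomes a 10,000-dimensional vector with a 1 at the index of every word it contains.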

    def vectorize_sequences(self, sequences, dimension=10000):
        results = np.zeros((len(sequences), dimension))
        for i,sequence in enumerate(sequences):
            results[i, sequence] = 1.
        return results
    
    self.x_train = x_train = self.vectorize_sequences(train_data)
    self.x_test = x_test = self.vectorize_sequences(test_data)
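
To see what this encoding keeps and what it throws away, here is a minimal sketch of the same idea in plain Python on a toy vocabulary of six indices. It is a standalone illustration that uses lists instead of NumPy arrays, not a drop-in replacement for the method above.

```python
def vectorize_sequences(sequences, dimension):
    """Multi-hot encode: row i has a 1.0 at every index that occurs in sequence i."""
    results = [[0.0] * dimension for _ in sequences]
    for i, sequence in enumerate(sequences):
        for index in sequence:
            results[i][index] = 1.0
    return results


# Two toy "documents" over a vocabulary of 6 word indices
toy = vectorize_sequences([[1, 3, 3], [0, 5]], dimension=6)
for row in toy:
    print(row)
```

Index 3 occurs twice in the first document but is still encoded as a single 1.0, and the order of the words plays no role: the vector only records which words are present.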

Encode the labels as one-hot vectors.

    def to_one_hot(self, labels, dimension=46):
        results = np.zeros((len(labels), dimension))
        for i,label in enumerate(labels):
            results[i, label] = 1
        return results
        
    self.one_hot_train_labels = one_hot_train_labels = self.to_one_hot(train_labels)
    self.one_hot_test_labels = one_hot_test_labels = self.to_one_hot(test_labels)    

3. Building the model

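There are 46 news categories, so the intermediate layers must not be too narrow, otherwise too much information is lost before it reaches the output. This model uses 64 hidden units per layer.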

        model = self.model = models.Sequential()
        model.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(46, activation='softmax'))
        model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
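
As a rough sanity check on the size of this network, the number of trainable parameters can be worked out by hand. The sketch below assumes each `Dense` layer has one weight per input-unit pair plus one bias per unit; `dense_params` is an illustrative helper, not a Keras function.

```python
def dense_params(inputs, units):
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units


# (input size, units) for the three Dense layers in the model above
layer_shapes = [(10000, 64), (64, 64), (64, 46)]
per_layer = [dense_params(n_in, n_out) for n_in, n_out in layer_shapes]
total = sum(per_layer)

print(per_layer)
print(total)
```

Almost all of the parameters sit in the first layer, because it connects the 10,000-dimensional input to the 64 hidden units.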

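The last layer outputs a 46-dimensional vector; each entry represents the probability that the sample belongs to the corresponding class.
Categorical crossentropy is used as the loss function because it measures the distance between two probability distributions, here the predicted distribution and the true label distribution.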
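
The following is a minimal sketch of what softmax and categorical crossentropy compute, written in plain Python for a toy 3-class problem. The scores are hypothetical and the functions are simplified stand-ins, not the Keras implementations.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shifted = [x - max(logits) for x in logits]  # subtract the max for numerical stability
    exps = [math.exp(x) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(one_hot_target, probs):
    """-sum(t * log(p)); with a one-hot target this is -log(p of the true class)."""
    return -sum(t * math.log(p) for t, p in zip(one_hot_target, probs) if t > 0)


# Hypothetical scores for a 3-class toy problem (the real model has 46 classes)
probs = softmax([2.0, 1.0, 0.1])
loss_right = categorical_crossentropy([1, 0, 0], probs)  # true class is the likeliest
loss_wrong = categorical_crossentropy([0, 0, 1], probs)  # true class is the least likely

print(probs, loss_right, loss_wrong)
```

The loss is small when the model puts high probability on the true class and grows as that probability shrinks, which is what training minimizes.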

4. Validating the model

Set aside part of the training set as a validation set.

        x_val = x_train[:1000]
        partial_x_train = x_train[1000:]

        y_val = one_hot_train_labels[:1000]
        partial_y_train = one_hot_train_labels[1000:]

Again, train in mini-batches of 512 samples, for 20 epochs.

        history = model.fit(partial_x_train, partial_y_train, epochs=self.epochs, batch_size=512, validation_data=(x_val, y_val))
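
For a feel of what one epoch means with these settings, the arithmetic below works out how many mini-batches the remaining training samples are split into. It only uses the figures given in this post (8,982 training samples, 1,000 held out, batch size 512) and assumes the last, smaller batch is kept rather than dropped.

```python
import math

total_train = 8982      # training samples reported by the dataset
val_size = 1000         # samples held out for validation
batch_size = 512

partial_train = total_train - val_size
steps_per_epoch = math.ceil(partial_train / batch_size)
last_batch = partial_train - (steps_per_epoch - 1) * batch_size

print(partial_train, steps_per_epoch, last_batch)
```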

Plot the loss curves.

    def plt_loss(self, history):
        plt.clf()
        loss = history.history['loss']
        val_loss = history.history['val_loss']
        epochs = range(1, len(loss) + 1)
        plt.plot(epochs, loss, 'bo', label='Training loss')
        plt.plot(epochs, val_loss, 'b', label='Validation loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()

Plot the accuracy curves.

    def plt_accuracy(self, history):
        plt.clf()
        acc = history.history['accuracy']
        val_acc = history.history['val_accuracy']
        epochs = range(1, len(acc) + 1)

        plt.plot(epochs, acc, 'bo', label='Training accuracy')
        plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()

The plots show that overfitting sets in after the ninth epoch, so retrain the model for 9 epochs and evaluate it on the test set.
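
Instead of reading the turning point off the plot, it can also be picked programmatically as the epoch with the lowest validation loss. The sketch below uses hypothetical loss values, since the post does not list the real ones; in an actual run the list would presumably come from `history.history['val_loss']`. The helper `best_epoch` is an illustrative name.

```python
# Hypothetical validation losses, one per epoch; these are NOT the
# values from the run described above, which the post does not list.
val_loss = [2.1, 1.6, 1.3, 1.15, 1.05, 1.0, 0.98, 0.95, 0.93, 0.96, 0.99, 1.04]


def best_epoch(losses):
    """Return the 1-based epoch with the lowest validation loss."""
    return min(range(len(losses)), key=losses.__getitem__) + 1


print(best_epoch(val_loss))
```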

    def evaluate(self):
        results = self.model.evaluate(self.x_test, self.one_hot_test_labels)
        print('evaluate test data:')
        print(results)

After this final training run, accuracy on the test set reaches about 79%.

evaluate test data:
[0.9847680330276489, 0.7925200462341309]
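
The two numbers are the test loss followed by the test accuracy. As a minimal sketch of what the accuracy figure means, the code below takes the most probable class of each prediction and compares it with the true label. The probabilities and labels are hypothetical toy values over 3 classes, and `accuracy` is an illustrative helper rather than the Keras metric.

```python
def accuracy(prob_rows, true_labels):
    """Fraction of samples whose most probable class equals the true label."""
    predicted = [row.index(max(row)) for row in prob_rows]
    hits = sum(1 for p, t in zip(predicted, true_labels) if p == t)
    return hits / len(true_labels)


# Hypothetical predictions for 4 samples over 3 classes
probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
]
labels = [0, 1, 2, 1]

print(accuracy(probs, labels))
```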

5. Summary

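  • The size of the last layer should match the number of classes.
  • For single-label multiclass classification, the last layer should use a softmax activation so that it outputs a probability distribution over the classes.
  • For single-label multiclass classification, use categorical crossentropy as the loss function.
  • The intermediate layers should not be narrower than the number of output classes, otherwise they become an information bottleneck.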

Complete source code

from tensorflow.keras.datasets import reuters
import numpy as np
from tensorflow.keras import models
from tensorflow.keras import layers
import matplotlib.pyplot as plt


class MultiClassifier:

    def __init__(self, num_words, epochs):
        self.num_words = num_words
        self.epochs = epochs
        self.model = None
        # Plot the curves on the 20-epoch exploratory run; otherwise evaluate on the test set
        self.eval = epochs != 20

    def load_data(self):
        return reuters.load_data(num_words=self.num_words)

    def get_text(self, data):
        word_id_index = reuters.get_word_index()
        id_word_index = dict([(id, value) for (value, id) in word_id_index.items()])
        return ' '.join([id_word_index.get(i - 3, '?') for i in data])

    def vectorize_sequences(self, sequences, dimension=10000):
        results = np.zeros((len(sequences), dimension))
        for i,sequence in enumerate(sequences):
            results[i, sequence] = 1.
        return results

    def to_one_hot(self, labels, dimension=46):
        results = np.zeros((len(labels), dimension))
        for i,label in enumerate(labels):
            results[i, label] = 1
        return results

    def plt_loss(self, history):
        plt.clf()
        loss = history.history['loss']
        val_loss = history.history['val_loss']
        epochs = range(1, len(loss) + 1)
        plt.plot(epochs, loss, 'bo', label='Training loss')
        plt.plot(epochs, val_loss, 'b', label='Validation loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()

    def plt_accuracy(self, history):
        plt.clf()
        acc = history.history['accuracy']
        val_acc = history.history['val_accuracy']
        epochs = range(1, len(acc) + 1)

        plt.plot(epochs, acc, 'bo', label='Training accuracy')
        plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()

    def evaluate(self):
        results = self.model.evaluate(self.x_test, self.one_hot_test_labels)
        print('evaluate test data:')
        print(results)


    def train(self):
        (train_data, train_labels), (test_data, test_labels) = self.load_data()
        print(len(train_data))
        print(len(test_data))
        print(train_data[0])
        print(train_labels[0])
        print(self.get_text(train_data[0]))

        self.x_train = x_train = self.vectorize_sequences(train_data)
        self.x_test = x_test = self.vectorize_sequences(test_data)

        self.one_hot_train_labels = one_hot_train_labels = self.to_one_hot(train_labels)
        self.one_hot_test_labels = one_hot_test_labels = self.to_one_hot(test_labels)

        model = self.model = models.Sequential()
        model.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(46, activation='softmax'))
        model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

        x_val = x_train[:1000]
        partial_x_train = x_train[1000:]

        y_val = one_hot_train_labels[:1000]
        partial_y_train = one_hot_train_labels[1000:]

        history = model.fit(partial_x_train, partial_y_train, epochs=self.epochs, batch_size=512, validation_data=(x_val, y_val))



        if self.eval:
            self.evaluate()
            print(self.model.predict(x_test))
        else:
            self.plt_loss(history)
            self.plt_accuracy(history)

classifier = MultiClassifier(num_words=10000, epochs=20)

# classifier = MultiClassifier(num_words=10000, epochs=9)
classifier.train()

