基於tf-idf的文本分類預測模型


為以后項目准備,在此寫一下文本分類預測模型的完整流程,使用的多項式朴素貝葉斯算法進行預測,在其他人項目中看到使用前饋神經網絡進行預測(本人目前沒有使用過深度學習進行文本分類,不知道效果怎么樣)

目前有2個問題未解決

  1. 模型建立完,怎樣預測一個新的文本文件(詞頻向量化無法處理)?
    解決方案:目前使用通過測試集和訓練集建的詞袋模型進行新文本的詞頻向量化,然后使用算法模型進行文本預測)
  2. 繪制PR曲線和ROC曲線
    解決方案:目前還沒有明白怎樣繪制,誰有好的方法求告知。
import os
import pandas as pd
import jieba
from sklearn.preprocessing import LabelEncoder  # 標簽編碼
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer  # 能夠完成詞頻向量化+去除停用詞
from sklearn.naive_bayes import MultinomialNB  # 多項式朴素貝葉斯
from sklearn.metrics import recall_score, precision_score, accuracy_score, f1_score  # 評估指標召回率,精確率, 准確率,f1[0,1]


# import xgboost

# 1.數據讀取--獲取數據
def read_text(base_dir_path):
    '''
    :param base_dir_path: 包含所有文件的路徑
    :return: dataFrame
    '''
    text_list = []
    # base_dir_path = './text_classification-master/text classification/train'
    list_dirs = os.listdir(base_dir_path)
    for i in list_dirs:
        file_path = os.path.join(base_dir_path, i)
        for j in os.listdir(file_path):
            try:
                with open(file_path + '/' + j) as f:
                    text_list.append([f.read(), i])
            except Exception as error:
                print(f'{file_path + "/" + j}文件讀取失敗')
                print(error)
    return pd.DataFrame(text_list, columns=['text', 'label'])


# 2.分詞------數據分析與處理

def cut_word(text):
    '''
    :param text: 待處理的文本序列
    :return: 處理后的文本序列
    '''
    return [' '.join(jieba.cut(i)) for i in text]


# 3. 停用詞
stopword = [i.strip() for i in
            open('./text_classification-master/text classification/stop/stopword.txt', encoding='utf-8').readlines()]


# 4.編碼器處理文本標簽--數據分析與處理
def label_encode(label):
    '''
    :param label: 待處理的標簽序列
    :return: 處理后的標簽序列
    '''
    le = LabelEncoder()
    e = le.fit_transform(label)
    print(dict(list(enumerate(le.classes_))))
    return e


if __name__ == '__main__':
    train = read_text('./text_classification-master/text classification/train')
    test = read_text('./text_classification-master/text classification/test')
    # print(train)
    # print(test)
    train['text'] = cut_word(train['text'])
    test['text'] = cut_word(test['text'])
    print(train)
    print(test)
    train['label_'] = label_encode(train['label'])
    test['label_'] = label_encode(test['label'])
    print(train)
    print(test)

    # 5.詞頻向量化---特征工程與選擇
    # 5.1使用tf - idf處理數據 -------------------------使用測試集評分0.8左右
    tf_idf = TfidfVectorizer(stop_words=stopword)  # 停用詞處理
    tf_idf.fit(list(train['text']) + list(test['text']))
    train_x = tf_idf.transform(train['text']).toarray()
    test_x = tf_idf.transform(test['text']).toarray()
    # print(train_x)
    # print(test_x)
    # 5.2 使用onehot_encode處理-------------------------使用測試集評分0.9左右(不知道為啥比tf-idf效果好)
    # counter = CountVectorizer(stop_words=stopword)
    # counter.fit(list(train['text'])+list(test['text']))
    # train_x = counter.transform(train['text']).toarray()
    # test_x = counter.transform(test['text']).toarray()
    # print(counter.vocabulary_) # 詞頻統計

    # 6.模型建立------算法模型
    nb = MultinomialNB()
    nb.fit(train_x, train['label_'])
    test_pre = nb.predict(test_x)

    # 7.模型評估與優化-----性能評估/參數優化
    # 得分
    print(nb.score(test_x, test['label_']))
    # 准確率
    print(f'准確率:{accuracy_score(test["label_"], test_pre)}')
    # 精確率
    print(f'精確率:{precision_score(test["label_"], test_pre, average="weighted")}')
    # 召回率
    print(f'召回率:{recall_score(test["label_"], test_pre, average="weighted")}')
    # f1_score
    print(f'f1:{f1_score(test["label_"], test_pre, average="weighted")}')

    # PR曲線/ROC曲線
     # 8. 新文本預測
    new_text = read_text('./text_classification-master/text classification/new_text', encoding='utf-8')
    new_text['text'] = cut_word(new_text['text'])
    new_text_x = counter.transform(new_text['text']).toarray() # 新文本分詞后直接使用訓練集和測試集建的詞袋模型,進行詞頻向 
                                                               # 量化,然后進行已建模型的預測
    print(new_text_x)
    map_dict =  {0: '體育', 1: '女性', 2: '文學', 3: '校園'}
    print(f'預測結果為:{[map_dict[i] for i in nb.predict(new_text_x)]}')


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM