pytorch實現自己的textCNN


對於初學深度學習的人來說,直接上手NLP的梯度較大。

首先,理解詞向量就有一定的困難。關於詞向量的的詳細描述,可以參考《word2vec Parameter Learning Explained》的解釋。一個100列的詞向量可以簡單理解為有100個特征(feature)的向量,如同一個人有100個特征一樣,這100個特征“完備”的描述了這個人的所有性質。

簡單理解了詞向量之后,作為初學者,肯定想自己訓練一個自己的詞向量,以加深理解。作為一個工程師,如何驗證自己的正確性是在開始編碼之前就需要考慮的。於是,我們把眼光瞄向了已經實現word2vec的第三方庫中...

 python的gensim庫了解一下...作為搞NLP的人這個庫必須知道,我們這里用它來生成詞向量。然后,接下來我們要處理中文,不可避免的需要進行中文分詞,這里我們選擇了jieba分詞(不是jiba)。然后是我們實驗需要使用的數據集,要做文本分類,這里我們找到的一個還不錯的數據集:thucnews, 大家可以在網上自行搜索;為了方便大家測試使用,這里給出我使用的下載鏈接:  提取碼: aqkg

這個數據集有14個類別,每個類別是一個文件夾,每個文件夾中包含了許多txt文件,每個txt文件是一段對應類別的新聞描述。接下來要做的,是利用數據集生成我們自己的第一份詞向量。其次,利用詞向量和數據集進行textCNN的訓練,在這個過程中,我們將會不斷的犯錯,在錯誤中學習。

源碼參考:https://github.com/webbery/NLPToolset

1. 生成詞向量

wv_thucnews.py代碼的主要功能是遍歷所有文件,對所有的文本創建對應的標簽列表,以及文本列表,這兩個列表是一一對應的;此外,在遍歷的同時還對所有文本進行了分詞操作。

gensim庫提供了LineSentence類,這個類支持從文本中讀取所有的句子。它對文本的格式要求是這樣的,它把每一行當成一個句子,並把所有的句子做成list返回。因此,如果要使用LineSentence,我們就需要按照這個格式來制作它的輸入文件。

zh.seg.txt存儲的格式如下圖所示,該文件即作為LineSentence的輸入。在設置好數據集的路徑之后,運行to_vector.py,第一個參數是已經分詞完畢的輸入數據,這里即zh.seg.txt,第二個參數是要保存的詞向量模型文件名。在代碼中,設置的詞向量的維度是100個feature。曾經試過800feature的,內存需要提供16G才能訓練完orz。因此如果只是學習使用,100個feature已經足夠了。

 

2. 數據預處理

在wv_thucnews.py階段已經產生了所有數據以及這些數據對應的標簽,因此這一階段的主要目標是對原始數據做一次shuffle,使樣本之間盡可能的無關。因此在載入數據之后,調用了train_test_split。

trainX, testX, trainY, testY = train_test_split(content,targets)

trainX是待訓練的數據,trainY是每個訓練數據對應的標簽。注意,到目前為止這里的數據還是字符串格式的,因此接下來要把訓練數據轉成詞向量格式,標簽數據轉成數字形式。

to_categorical 函數的主要任務就是將標簽字符串轉為數字形式,因為數據集的類型總共有14類,因此標簽的值從0~13。
然后在代碼中使用yield語法返回batch_size個數據,提供訓練。
PS:這里建議預先查看下數據集中一個句子的長度,這樣訓練時可以設置句子平均的合適長度, 長度不足的部分補0。

3. textCNN

我們使用pytorch編寫模型。按照《深度學習》聖經的說法,“卷積網絡中一個典型層包含三級”,第一級計算多個卷積,第二級通過一個非線性的激活函數,第三級將通過池化函數調整輸出。因此我們經常看到這樣的寫法:

 

1 self.conv1 = nn.Sequential(
2             nn.Conv1d(vec_dim,64,kernel_size,padding=2),
3             nn.ReLU(),
4             nn.MaxPool1d(3,stride=1,padding=1)
5         )

 

Convolutional Neural Networks for Sentence Classification》的textCNN中使用了三個卷積網絡。如果需要驗證論文的同學,可以嚴格按照文章的網絡和參數進行設置。在我的textCNN.py實現中,使用了三個網絡,但跟論文的參數並不相同,因為自行訓練得到的詞向量維度和數據的batch大小不同。

 

4. 如何調試

對於所有的程序來說,調試是必不可少的。觀察loss和accuracy是最基本的方法。理論上來說loss曲線應該是逐漸下降且逼近於一個值,accuracy應該是逐漸上升且逼近一個值。如果中間有較大的跳躍,可能是算法編寫的地方有問題,一般說來,可能是數據集沒有suffle,每個數據之間關聯性較大;也有可能是產生了過擬合,等等。例如下面兩個曲線是我沒有對label標簽去重導致的...

 

而正確的accuracy圖應該類似於下方這張:

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM