一、數據集介紹
數據來源:今日頭條客戶端
數據格式如下:
6551700932705387022_!_101_!_news_culture_!_京城最值得你來場文化之旅的博物館_!_保利集團,馬未都,中國科學技術館,博物館,新中國
6552368441838272771_!_101_!_news_culture_!_發酵床的墊料種類有哪些?哪種更好?_!_
6552407965343678723_!_101_!_news_culture_!_上聯:黃山黃河黃皮膚黃土高原。怎么對下聯?_!_
6552332417753940238_!_101_!_news_culture_!_林徽因什么理由拒絕了徐志摩而選擇梁思成為終身伴侶?_!_
6552475601595269390_!_101_!_news_culture_!_黃楊木是什么樹?_!_
每行為一條數據,以_!_分割的個字段,從前往后分別是 新聞ID,分類code(見下文),分類名稱(見下文),新聞字符串(僅含標題),新聞關鍵詞
分類code與名稱:
100 民生 故事 news_story
101 文化 文化 news_culture
102 娛樂 娛樂 news_entertainment
103 體育 體育 news_sports
104 財經 財經 news_finance
106 房產 房產 news_house
107 汽車 汽車 news_car
108 教育 教育 news_edu
109 科技 科技 news_tech
110 軍事 軍事 news_military
112 旅游 旅游 news_travel
113 國際 國際 news_world
114 證券 股票 stock
115 農業 三農 news_agriculture
116 電競 游戲 news_game
github地址:https://github.com/fate233/toutiao-text-classfication-dataset
數據資源中給出了分類的實驗結果:
Test Loss: 0.57, Test Acc: 83.81%
precision recall f1-score support
news_story 0.66 0.75 0.70 848
news_culture 0.57 0.83 0.68 1531
news_entertainment 0.86 0.86 0.86 8078
news_sports 0.94 0.91 0.92 7338
news_finance 0.59 0.67 0.63 1594
news_house 0.84 0.89 0.87 1478
news_car 0.92 0.90 0.91 6481
news_edu 0.71 0.86 0.77 1425
news_tech 0.85 0.84 0.85 6944
news_military 0.90 0.78 0.84 6174
news_travel 0.58 0.76 0.66 1287
news_world 0.72 0.69 0.70 3823
stock 0.00 0.00 0.00 53
news_agriculture 0.80 0.88 0.84 1701
news_game 0.92 0.87 0.89 6244
avg / total 0.85 0.84 0.84 54999
下面我們就來用deeplearning4j來實現一個卷積結構對該數據集進行分類,看能不能得到更好的結果。
二、卷積網絡可以用於文本處理的原因
CNN非常適合處理圖像數據,前面一篇文章《deeplearning4j——卷積神經網絡對驗證碼進行識別》介紹了CNN對驗證碼進行識別。本篇博客將利用CNN對文本進行分類,在開始之前我們先來直觀的說說卷積運算在做的本質事情是什么。卷積運算,本質上可以看做兩個向量的點積,兩個向量越同向,點積就越大,經過relu和MaxPooling之后,本質上是提取了與卷積核最同向的結構,這個“結構”實際上是圖片上的一些線條。
那么文本可以用CNN來處理嗎?答案是肯定的,文本每個詞用向量表示之后,依次排開,就變成了一張二維圖,如下圖,沿着紅色箭頭的方向(也就是文本的方向)看,兩個句子用一幅圖表示之后,會出現相同的單元,也就可以用CNN來處理。
三、文本處理的卷積結構
那么,怎么設計這個CNN網絡結構呢?如下圖:(論文地址:https://arxiv.org/abs/1408.5882)
注意點:
1、卷積核移動的方向必須為句子的方向
2、每個卷積核提取的特征為N行1列的向量
3、MaxPooling的操作的對象是每一個Feature Map,也就是從每一個N行1列的向量中選擇一個最大值
4、把選擇的所有最大值接起來,經過幾個Fully Connected 層,進行分類
四、數據的預處理與詞向量
1、分詞工具:HanLP
2、處理后的數據格式如下:(類別code_!_詞,其中,詞與詞之間用空格隔開,_!_為分割符)
數據預處理代碼如下:
public static void main(String[] args) throws Exception {
BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
new FileInputStream(new File("/toutiao_cat_data/toutiao_cat_data.txt")), "UTF-8"));
OutputStreamWriter writerStream = new OutputStreamWriter(
new FileOutputStream("/toutiao_cat_data/toutiao_data_type_word.txt"), "UTF-8");
BufferedWriter writer = new BufferedWriter(writerStream);
String line = null;
long startTime = System.currentTimeMillis();
while ((line = bufferedReader.readLine()) != null) {
String[] array = line.split("_!_");
StringBuilder stringBuilder = new StringBuilder();
for (Term term : HanLP.segment(array[3])) {
if (stringBuilder.length() > 0) {
stringBuilder.append(" ");
}
stringBuilder.append(term.word.trim());
}
writer.write(Integer.parseInt(array[1].trim()) + "_!_" + stringBuilder.toString() + "\n");
}
writer.flush();
writer.close();
System.out.println(System.currentTimeMillis() - startTime);
bufferedReader.close();
}
五、詞的向量表示
1、one-hot
用正交的向量來表示每一個詞,這樣表示無法反應詞與詞之間的關系,那么兩句話中,要想復用同一個卷積核,那么必須出現一模一樣的詞才可以,實際上,我們要求模型可以舉一反三,連相似的結構也可以提取,那么word2vec可以解決這個問題。
2、word2vec
word2vec可以充分考慮詞與詞之間的關系,相似的詞,肯定有某些維度靠的比較近。那么也就考慮了詞的語句之間的關系,訓練word2vec有兩種,skipgram和cbow,下面我們用cbow來訓練詞向量,結果會持久化下來,就得到了toutiao.vec的文件,下次變可重新加載該文件獲得詞的向量表示,代碼如下:
String filePath = new ClassPathResource("toutiao_data_word.txt").getFile().getAbsolutePath();
SentenceIterator iter = new BasicLineIterator(filePath);
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
VocabCache<VocabWord> cache = new AbstractCache<>();
WeightLookupTable<VocabWord> table = new InMemoryLookupTable.Builder<VocabWord>().vectorLength(100)
.useAdaGrad(false).cache(cache).build();
log.info("Building model....");
Word2Vec vec = new Word2Vec.Builder()
.elementsLearningAlgorithm("org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW")
.minWordFrequency(0).iterations(1).epochs(20).layerSize(100).seed(42).windowSize(8).iterate(iter)
.tokenizerFactory(t).lookupTable(table).vocabCache(cache).build();
vec.fit();
WordVectorSerializer.writeWord2VecModel(vec, "/toutiao_cat_data/toutiao.vec");
六、CNN網絡結構
CNN網絡結構如下:
說明:
1、cnn3、cnn4、cnn5、cnn6卷積核大小為(3,vectorSize)、(4,vectorSize)、(5,vectorSize)、(6,vectorSize),步幅為1,也就是分別讀取3、4、5、6個詞,提取特征
2、cnn3-stride2、cnn4-stride2、cnn5-stride2、cnn6-stride2卷積核大小為(3,vectorSize)、(4,vectorSize)、(5,vectorSize)、(6,vectorSize),步幅為2
3、兩組卷積核卷積的結果合並,分別得到merge1和merge2,都是4維張量,形狀分別為(batchSize,depth1+depth2+depth3,height/1,1),(batchSize,depth1+depth2+depth3,height/2,1),特別說明:這里的卷積模式為ConvolutionMode.Same
4、merge1、2分別經過MaxPooling,這里用的是GlobalPoolingLayer,和平台的Pooling層不同,這里會從指定維度中,取一個最大值,所以經過GlobalPoolingLayer之后,merge1、2分別變成2維張量,形狀為(batchSize,depth1+depth2+depth3),那么GlobalPoolingLayer是如何求Max的呢?源碼如下:
private INDArray activateHelperFullArray(INDArray inputArray, int[] poolDim) {
switch (poolingType) {
case MAX:
return inputArray.max(poolDim);
case AVG:
return inputArray.mean(poolDim);
case SUM:
return inputArray.sum(poolDim);
case PNORM:
//P norm: https://arxiv.org/pdf/1311.1780.pdf
//out = (1/N * sum( |in| ^ p) ) ^ (1/p)
int pnorm = layerConf().getPnorm();
INDArray abs = Transforms.abs(inputArray, true);
Transforms.pow(abs, pnorm, false);
INDArray pNorm = abs.sum(poolDim);
return Transforms.pow(pNorm, 1.0 / pnorm, false);
default:
throw new RuntimeException("Unknown or not supported pooling type: " + poolingType + " " + layerId());
}
}
5、兩邊GlobalPoolingLayer結果再接起來,丟給全連接網絡,經過softmax分類器進行分類
6、fc層,用了0.5的dropout防止過擬合,在下面的代碼中可以看到。
完整代碼如下:
public class CnnSentenceClassificationTouTiao {
public static void main(String[] args) throws Exception {
List<String> trainLabelList = new ArrayList<>();// 訓練集label
List<String> trainSentences = new ArrayList<>();// 訓練集文本集合
List<String> testLabelList = new ArrayList<>();// 測試集label
List<String> testSentences = new ArrayList<>();//// 測試集文本集合
Map<String, List<String>> map = new HashMap<>();
BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
new FileInputStream(new File("/toutiao_cat_data/toutiao_data_type_word.txt")), "UTF-8"));
String line = null;
int truncateReviewsToLength = 0;
Random random = new Random(123);
while ((line = bufferedReader.readLine()) != null) {
String[] array = line.split("_!_");
if (map.get(array[0]) == null) {
map.put(array[0], new ArrayList<String>());
}
map.get(array[0]).add(array[1]);// 將樣本中所有數據,按照類別歸類
int length = array[1].split(" ").length;
if (length > truncateReviewsToLength) {
truncateReviewsToLength = length;// 求樣本中,句子的最大長度
}
}
bufferedReader.close();
for (Map.Entry<String, List<String>> entry : map.entrySet()) {
for (String sentence : entry.getValue()) {
if (random.nextInt() % 5 == 0) {// 每個類別抽取20%作為test集
testLabelList.add(entry.getKey());
testSentences.add(sentence);
} else {
trainLabelList.add(entry.getKey());
trainSentences.add(sentence);
}
}
}
int batchSize = 64;
int vectorSize = 100;
int nEpochs = 10;
int cnnLayerFeatureMaps = 50;
PoolingType globalPoolingType = PoolingType.MAX;
Random rng = new Random(12345);
Nd4j.getMemoryManager().setAutoGcWindow(5000);
ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().weightInit(WeightInit.RELU)
.activation(Activation.LEAKYRELU).updater(new Nesterovs(0.01, 0.9))
.convolutionMode(ConvolutionMode.Same).l2(0.0001).graphBuilder().addInputs("input")
.addLayer("cnn3",
new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(1, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn4",
new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(1, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn5",
new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(1, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn6",
new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(1, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn3-stride2",
new ConvolutionLayer.Builder().kernelSize(3, vectorSize).stride(2, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn4-stride2",
new ConvolutionLayer.Builder().kernelSize(4, vectorSize).stride(2, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn5-stride2",
new ConvolutionLayer.Builder().kernelSize(5, vectorSize).stride(2, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addLayer("cnn6-stride2",
new ConvolutionLayer.Builder().kernelSize(6, vectorSize).stride(2, vectorSize)
.nOut(cnnLayerFeatureMaps).build(),
"input")
.addVertex("merge1", new MergeVertex(), "cnn3", "cnn4", "cnn5", "cnn6")
.addLayer("globalPool1", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
"merge1")
.addVertex("merge2", new MergeVertex(), "cnn3-stride2", "cnn4-stride2", "cnn5-stride2", "cnn6-stride2")
.addLayer("globalPool2", new GlobalPoolingLayer.Builder().poolingType(globalPoolingType).build(),
"merge2")
.addLayer("fc",
new DenseLayer.Builder().nOut(200).dropOut(0.5).activation(Activation.LEAKYRELU).build(),
"globalPool1", "globalPool2")
.addLayer("out",
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nOut(15).build(),
"fc")
.setOutputs("out").setInputTypes(InputType.convolutional(truncateReviewsToLength, vectorSize, 1))
.build();
ComputationGraph net = new ComputationGraph(config);
net.init();
System.out.println(net.summary());
Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel("/toutiao_cat_data/toutiao.vec");
System.out.println("Loading word vectors and creating DataSetIterators");
DataSetIterator trainIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, trainLabelList,
trainSentences, rng);
DataSetIterator testIter = getDataSetIterator(word2Vec, batchSize, truncateReviewsToLength, testLabelList,
testSentences, rng);
UIServer uiServer = UIServer.getInstance();
StatsStorage statsStorage = new InMemoryStatsStorage();
uiServer.attach(statsStorage);
net.setListeners(new ScoreIterationListener(100), new StatsListener(statsStorage, 20),
new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END));
// net.setListeners(new ScoreIterationListener(100),
// new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END));
net.fit(trainIter, nEpochs);
}
private static DataSetIterator getDataSetIterator(WordVectors wordVectors, int minibatchSize, int maxSentenceLength,
List<String> lableList, List<String> sentences, Random rng) {
LabeledSentenceProvider sentenceProvider = new CollectionLabeledSentenceProvider(sentences, lableList, rng);
return new CnnSentenceDataSetIterator.Builder().sentenceProvider(sentenceProvider).wordVectors(wordVectors)
.minibatchSize(minibatchSize).maxSentenceLength(maxSentenceLength).useNormalizedWordVectors(false)
.build();
}
}
代碼說明:
1、代碼分兩部分,第一部分是數據預處理,分出20%測試集、80%作為訓練集
2、第二部分為網絡的基本結構代碼
網絡參數詳細如下:
===============================================================================================================================================
VertexName (VertexType) nIn,nOut TotalParams ParamsShape Vertex Inputs
===============================================================================================================================================
input (InputVertex) -,- - - -
cnn3 (ConvolutionLayer) 1,50 15050 W:{50,1,3,100}, b:{1,50} [input]
cnn4 (ConvolutionLayer) 1,50 20050 W:{50,1,4,100}, b:{1,50} [input]
cnn5 (ConvolutionLayer) 1,50 25050 W:{50,1,5,100}, b:{1,50} [input]
cnn6 (ConvolutionLayer) 1,50 30050 W:{50,1,6,100}, b:{1,50} [input]
cnn3-stride2 (ConvolutionLayer) 1,50 15050 W:{50,1,3,100}, b:{1,50} [input]
cnn4-stride2 (ConvolutionLayer) 1,50 20050 W:{50,1,4,100}, b:{1,50} [input]
cnn5-stride2 (ConvolutionLayer) 1,50 25050 W:{50,1,5,100}, b:{1,50} [input]
cnn6-stride2 (ConvolutionLayer) 1,50 30050 W:{50,1,6,100}, b:{1,50} [input]
merge1 (MergeVertex) -,- - - [cnn3, cnn4, cnn5, cnn6]
merge2 (MergeVertex) -,- - - [cnn3-stride2, cnn4-stride2, cnn5-stride2, cnn6-stride2]
globalPool1 (GlobalPoolingLayer) -,- 0 - [merge1]
globalPool2 (GlobalPoolingLayer) -,- 0 - [merge2]
fc-merge (MergeVertex) -,- - - [globalPool1, globalPool2]
fc (DenseLayer) 400,200 80200 W:{400,200}, b:{1,200} [fc-merge]
out (OutputLayer) 200,15 3015 W:{200,15}, b:{1,15} [fc]
-----------------------------------------------------------------------------------------------------------------------------------------------
Total Parameters: 263615
Trainable Parameters: 263615
Frozen Parameters: 0
===============================================================================================================================================
DL4J的UIServer界面如下,這里我給定的端口號為9001,打開web界面可以看到平均loss的詳情,梯度更新的詳情等
http://localhost:9001/train/overview
七、掩模
句子有長有短,CNN將如何處理呢?
處理的辦法其實很暴力,將一個minibatch中的最長句子找到,new出最大長度的張量,多余值用掩模掩掉即可,廢話不多說,直接上代碼
if(sentencesAlongHeight){
featuresMask = Nd4j.create(currMinibatchSize, 1, maxLength, 1);
for (int i = 0; i < currMinibatchSize; i++) {
int sentenceLength = tokenizedSentences.get(i).getFirst().size();
if (sentenceLength >= maxLength) {
featuresMask.slice(i).assign(1.0);
} else {
featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.interval(0, sentenceLength), NDArrayIndex.point(0)).assign(1.0);
}
}
} else {
featuresMask = Nd4j.create(currMinibatchSize, 1, 1, maxLength);
for (int i = 0; i < currMinibatchSize; i++) {
int sentenceLength = tokenizedSentences.get(i).getFirst().size();
if (sentenceLength >= maxLength) {
featuresMask.slice(i).assign(1.0);
} else {
featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.interval(0, sentenceLength)).assign(1.0);
}
}
}
這里為什么有個if呢?生成句子張量的時候,可以任意指定句子的方向,可以沿着矩陣中height的方向,也可以是width的方向,方向不同,填掩模的那一維也就不同。
八、結果
運行了10個Epoch結果如下:
========================Evaluation Metrics========================
# of classes: 15
Accuracy: 0.8420
Precision: 0.8362 (1 class excluded from average)
Recall: 0.7783
F1 Score: 0.8346 (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 15 classes)
Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [12]
=========================Confusion Matrix=========================
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
----------------------------------------------------------------------------
973 35 114 2 9 8 11 19 14 6 19 11 0 22 13 | 0 = 0
17 4636 250 37 51 16 14 151 47 29 232 36 0 82 44 | 1 = 1
103 176 6980 108 16 8 31 62 83 41 53 77 0 36 163 | 2 = 2
9 78 244 6692 37 9 52 59 33 27 57 54 0 10 96 | 3 = 3
7 52 36 31 4072 96 101 107 581 20 64 108 0 135 37 | 4 = 4
12 18 22 8 150 3061 27 36 53 2 100 16 0 56 2 | 5 = 5
17 38 71 26 94 13 6443 43 174 31 121 39 0 32 34 | 6 = 6
17 157 93 49 62 20 34 4793 85 14 58 36 0 49 31 | 7 = 7
1 45 71 21 436 30 195 138 7018 48 54 49 0 45 148 | 8 = 8
24 74 84 47 24 1 57 50 68 3963 45 431 0 9 65 | 9 = 9
9 165 90 21 40 37 61 40 42 21 3428 111 0 78 30 | 10 = 10
47 78 173 52 114 20 48 67 93 320 140 4097 0 48 29 | 11 = 11
0 0 0 0 60 0 1 0 5 0 0 0 0 0 0 | 12 = 12
35 105 31 6 139 37 34 61 79 11 153 35 0 3187 12 | 13 = 13
14 36 210 128 31 2 19 20 164 44 38 15 0 19 5183 | 14 = 14
平均准確率0.8420,比原資源中給定的結果略好,F1 score要略差一點,混淆矩陣中,有一個類別,無法被預測到,是因為樣本中改類別數據量本身很少,難以抓到共性特征。這里參數如果精心調節一番,迭代更多次數,理論上會有更好的表現。
九、后記
讀Deeplearning4j是一種享受,優雅的架構,清晰的邏輯,多種設計模式,擴展性強,將有后續博客,對dl4j源碼進行剖析。
快樂源於分享。
此博客乃作者原創, 轉載請注明出處