代碼放在github上:click me
一、數據說明
數據集為英文語料集,一共包含20種類別的郵件,除了類別soc.religion.christian的郵件數為997以外每個類別的郵件數都是1000。每份郵件內部包含發送者,接受者,正文等信息。
二、實驗方法
2.1 數據預處理
數據預處理階段采用了幾種方案進行測試
-
直接將郵件內容按空格分詞
-
使用stanford corenlp進行分詞,然后使用停詞表過濾分詞結果
-
使用stanford corenlp進行分詞,並根據詞性和停詞表過濾分詞結果
綜合上面三種方案,測試結果最好的是方案二的預處理方式。將所有的郵件預處理之后寫入一個文件中,文件每行對應一封郵件,形式如"類別\t按空格分隔的郵件分詞"
comp.os.ms-windows.misc I search Ms-Windows logo picture start Windows
misc.forsale ITEMS for SALEI offer item I reserve the right refuse offer Howard Miller
comp.sys.ibm.pc.hardware I hd bad suggest inadequate power supply how wattage
2.2 pipeline建模
基於spark的pipeline構建端到端的分類模型
-
將數據預處理后得到的文件上傳到hdfs上,spark從hdfs上讀取文本數據並轉換成DataFrame
-
為DataFrame的郵件類別列建立索引,然后將DataFrame作為Word2Vec的輸入獲取句子的向量表示
-
句子向量輸入到含有2層隱藏層的多層感知機(MLP)中進行分類學習
-
將預測結果的索引列轉換成可讀的郵件類別標簽
三、實驗結果
將數據集隨機划分成8:2,80%的數據作為訓練集,20%的數據作為測試集。經過合理的調參,在測試集上的accuracy和F1 score可以達到90.5%左右,關鍵參數設置如下
// Word2Vec超參
final val W2V_MAX_ITER = 5 // Word2Vec迭代次數
final val EMBEDDING_SIZE = 128 // 詞向量長度
final val MIN_COUNT = 1 // default: 5, 詞匯表閾值即至少出現min_count次才放入詞匯表中
final val WINDOW_SIZE = 5 // default: 5, 上下文窗口大小[-WINDOW_SIZE,WINDOW_SIZE]
// MLP超參
final val MLP_MAX_ITER = 300 // MLP迭代次數
final val SEED = 1234L // 隨機數種子,初始化網絡權重用
final val HIDDEN1_SIZE = 64 // 第一層隱藏層節點數
final val HIDDEN2_SIZE = 32 // 第二層隱藏層節點數
final val LABEL_SIZE = 20 // 輸出層節點數
郵件預測結果輸出在hdfs上,文件內容每行的de左邊是真實label,右邊是預測label

四、實驗運行
4.1 環境要求
hadoop-2.7.5
spark-2.3.0
stanford corenlp 3.9.2
4.2 源代碼說明
Maven項目文件結構如下

src/main/scala下為源代碼,其中Segment.java和EnglishSegment.java用於英文分詞,DataPreprocess.scala基於分詞作數據預處理,MailClassifier.scala對應郵件分類模型。input下為數據集,output下為數據預處理結果MailCollection和預測結果prediction,target下為maven打好的jar包Mail.jar以及運行腳本submit.sh,pom.xml為maven配置。
4.3 運行方式
將數據集20_newsgroup放在input目錄下,確保pom.xml中的依賴包都滿足以后運行DataPreprocess得到預處理的結果MailCollection輸出到output目錄下。啟動hadoop的hdfs,將MailCollection上傳到hdfs上以便spark讀取。然后啟動spark,命令行下進入到target路徑下運行./submit.sh提交任務,submit.sh內容如下
spark-submit --class MailClassifier --master spark://master:7077 --conf spark.driver.memory=10g --conf spark.executor.memory=4g --conf spark.executor.cores=2 --conf spark.kryoserializer.buffer=512m --conf spark.kryoserializer.buffer.max=1g Mail.jar input/MailCollection output
運行MailClassifier需要兩個命令行參數,其中input/MailCollection為上傳到hdfs上的路徑名,output為預測結果輸出到hdfs上的路徑名,提交任務前確保輸出路徑在hdfs上不存在,否則程序會刪除輸出輸出路徑以確保程序正確運行。