Naive Bayes (Binary Classification) [Machine Learning in Action]


Data link

Spam SMS classification

Analysis

Take a point (x, y) that we want to assign to one of two classes, 1 or 2. Let p1(x,y) and p2(x,y) be the probabilities that the point belongs to class 1 and class 2, respectively:

If p1(x,y) > p2(x,y), predict class 1.
If p1(x,y) < p2(x,y), predict class 2.

By Bayes' theorem we have

\[ p(c|x,y) = \frac{p(x,y|c)\,p(c)}{p(x,y)} \qquad (1) \]

For the binary case this gives
$$p1 \rightarrow p(1|x,y)$$
$$p2 \rightarrow p(2|x,y)$$
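
Since both posteriors in (1) share the same denominator p(x,y), it cancels when they are compared; deciding between the two classes only requires comparing the numerators:

$$p(x,y|1)\,p(1) \quad \text{vs.} \quad p(x,y|2)\,p(2)$$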

Look at the right-hand side of equation (1). By the law of large numbers, once the dataset is reasonably large each probability can be approximated by its observed frequency, so every term on the right can be estimated by simple counting (for example, p(spam) ≈ number of spam messages / total number of messages).

Naive Bayes therefore boils down to counting statistics over the data, in three steps:

  1. Tokenize the text and build the vocabulary (the word vector space). English text is already space-delimited; Chinese text can be segmented with a tool such as jieba.
  2. Map each document to its coordinates in that space (the count of each vocabulary word, i.e. a bag of words), then estimate by counting (see the sketch after this list):
     1) p(c), the frequency of each class;
     2) p(x,y), by estimating each single-feature probability and applying the multiplication rule;
     3) p(x,y|c), by counting the occurrences of (x,y) within class c (in practice per feature, relying on the conditional-independence, i.e. "naive", assumption).
  3. For a given word vector, compute p1 and p2; the answer is the class with the larger value.
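
As a minimal sketch of that counting step, on a few made-up toy documents rather than the SMS data, and without the smoothing and TF-IDF weighting used in the full code below:

import numpy as np

# Toy corpus: already-tokenized documents with labels 1 = spam, 0 = ham.
docs = [['win', 'cash', 'now'], ['meeting', 'at', 'noon'], ['win', 'a', 'prize']]
labels = np.array([1, 0, 1])

vocab = sorted({w for d in docs for w in d})
counts = np.array([[d.count(w) for w in vocab] for d in docs])   # bag-of-words matrix

p_spam = labels.mean()                            # p(c=1) estimated as a class frequency
spam_counts = counts[labels == 1].sum(axis=0)     # word counts inside the spam class
p_w_given_spam = spam_counts / spam_counts.sum()  # p(w|c=1), estimated by counting
print(p_spam, dict(zip(vocab, p_w_given_spam)))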

Machine Learning in Action suggests one refinement:

Because the individual probabilities are very small (and any zero count would drive the whole product to zero), the product of probabilities is replaced by a sum of logarithms; since the logarithm is monotonically increasing, the comparison between classes is unchanged. (The zero counts themselves are handled in the code below by Laplace smoothing, i.e. initializing the counts to 1.)
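
A quick illustration of the underflow problem, with made-up numbers:

import numpy as np

# 500 small word probabilities: their raw product underflows to 0.0, but the
# sum of their logs stays finite, and since log is monotonic the comparison
# between classes is unaffected.
p = np.full(500, 1e-4)
print(np.prod(p))          # 0.0  (floating-point underflow)
print(np.sum(np.log(p)))   # about -4605.2, still fine for comparing classes

The full implementation follows: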

from numpy import *

def textParse1(vec):
    # Training row: first field is the label ('spam'/'ham'); the remaining
    # comma-separated fields are kept as the token list.
    return (1 if vec[0] == 'spam' else 0), vec[1:]

def textParse2(vec):
    # Test row: first field is the SMS id; the rest are kept as tokens.
    return vec[0], vec[1:]
    
def bagOfWords2VecMN(vocabList, inputSet):
    # Bag-of-words model: count how many times each vocabulary word appears.
    returnVec = [0]*len(vocabList)
    for word in inputSet:
        if word in vocabList:
            returnVec[vocabList.index(word)] += 1
    return returnVec

def setOfWords2VecMN(vocabList, inputSet):
    # Set-of-words model: only record whether each vocabulary word appears.
    returnVec = [0]*len(vocabList)
    for word in inputSet:
        if word in vocabList:
            returnVec[vocabList.index(word)] = 1
    return returnVec

def createVocabList(dataSet):
    # The vocabulary is the union of the token sets of all documents.
    vocabSet = set([])
    for document in dataSet:
        vocabSet = vocabSet | set(document)
    return list(vocabSet)

def tfIdf(trainMatrix, setMatrix):
    # TF-IDF weights: trainMatrix holds raw word counts (bag of words),
    # setMatrix holds 0/1 presence flags (set of words) for document frequency.
    n = len(trainMatrix)                                 # number of documents
    m = len(trainMatrix[0])                              # vocabulary size
    b = array(sum(trainMatrix, axis=1), dtype='float')   # total words per document
    c = array(sum(setMatrix, axis=0), dtype='float')     # documents containing each word
    weight = []
    for i in range(m):
        tf = trainMatrix[:, i]/b                         # term frequency of word i per document
        weight.append(tf * log(n/c[i]))                  # times inverse document frequency
    return array(weight).transpose()                     # shape: (documents, vocabulary)
    

def trainNB0(trainMatrix, trainCategory, weight):
    # Estimate log p(w|c) for both classes and the spam prior p(c=1).
    numTrainDocs = len(trainMatrix)
    numWords = len(trainMatrix[0])
    pAbusive = sum(trainCategory)/float(numTrainDocs)   # p(c=1) as a frequency
    p0Num = ones(numWords); p1Num = ones(numWords)      # Laplace smoothing: start counts at 1
    p0Denom = 2.0; p1Denom = 2.0
    for i in range(numTrainDocs):
        if trainCategory[i] == 1:
            p1Num += trainMatrix[i]*weight[i]
            p1Denom += sum(trainMatrix[i]*weight[i])
        else:
            p0Num += trainMatrix[i]*weight[i]
            p0Denom += sum(trainMatrix[i]*weight[i])
    p1Vect = log(p1Num/p1Denom)   # log p(w|spam), TF-IDF weighted
    p0Vect = log(p0Num/p0Denom)   # log p(w|ham), TF-IDF weighted
    return p0Vect, p1Vect, pAbusive

def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):
    # Compare log p(c) + sum over words of count(w)*log p(w|c) for the two classes.
    p1 = sum(vec2Classify * p1Vec) + log(pClass1)
    p0 = sum(vec2Classify * p0Vec) + log(1.0 - pClass1)
    if p1 > p0:
        return 1
    else:
        return 0

def spamTest():
    import csv
    trainFile = './train.csv'
    testFile = './test.csv'
    docList = []; classList = []

    # prepare trainData
    in1 = open(trainFile); in1.readline()            # skip the header row
    trainData = [row for row in csv.reader(in1)]
    in1.close()
    for row in trainData:
        label, wordList = textParse1(row)
        docList.append(wordList)
        classList.append(label)

    vocabList = createVocabList(docList)

    trainMat = []; trainClasses = []
    setMat = []
    for docIndex in range(len(docList)):
        trainMat.append(bagOfWords2VecMN(vocabList, docList[docIndex]))
        setMat.append(setOfWords2VecMN(vocabList, docList[docIndex]))
        trainClasses.append(classList[docIndex])

    # training by bayes
    weight = tfIdf(array(trainMat), array(setMat))
    p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses), array(weight))

    # prepare testData
    in2 = open(testFile); in2.readline()             # skip the header row
    testData = [row for row in csv.reader(in2)]
    in2.close()

    # predict testData and write the submission file
    out = open('predict.csv', 'w', newline='')
    fw = csv.writer(out)
    fw.writerow(['SmsId', 'Label'])
    for row in testData:
        smsId, wordList = textParse2(row)
        wordVector = bagOfWords2VecMN(vocabList, wordList)
        fw.writerow([smsId, 'spam' if classifyNB(array(wordVector), p0V, p1V, pSpam) else 'ham'])
    out.close()
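
To run the whole pipeline, with train.csv and test.csv in the working directory as the paths above assume, a standard entry point can be appended:

if __name__ == '__main__':
    spamTest()    # writes the predictions to predict.csv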

