Applications of the Bayes Classifier
-
Spam Filtering
The best-known application of the Bayes classifier is spam filtering. If you want the background in more depth, read the corresponding chapters of Hackers & Painters (《黑客與畫家》) or The Beauty of Mathematics (《數學之美》); for the basic Bayes implementation, see here.
Dataset
Two folders, one of normal (ham) emails and one of spam, each containing 25 emails.
Testing Method
Randomly select 10 of the 50 emails as the test data.
Implementation Details
1. First we need to turn the raw text into the vectors we work with, which takes a little regular-expression work (see the small tokenization sketch after this list).
2. Because we use (hold-out) cross-validation, the random selection of test mails means the result differs slightly from run to run (see the averaging sketch after the full listing below).
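As a minimal illustration of point 1, here is how the textParse / createVocabList / setOfWords2Vec functions from the full listing below turn two made-up strings into a vocabulary and a 0/1 word vector:

# Toy illustration of the text-to-vector step; uses textParse,
# createVocabList and setOfWords2Vec from the full listing below.
# The two example strings are invented for the demo.
docs = [textParse('Hello, this is a test email about cheap watches'),
        textParse('Another mail, nothing special here')]
vocab = createVocabList(docs)           # every distinct token seen in the docs
vec = setOfWords2Vec(vocab, docs[0])    # 0/1 presence vector for the first doc
print vec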
#coding=utf-8
from numpy import *

# Tokenize a document: split on non-word characters, lowercase,
# and keep only tokens longer than two characters
def textParse(bigString):
    import re
    listOfTokens = re.split(r'\W+', bigString)
    return [tok.lower() for tok in listOfTokens if len(tok) > 2]

# Build a list of all distinct words in the dataset
def createVocabList(dataSet):
    vocabSet = set([])
    for document in dataSet:
        vocabSet = vocabSet | set(document)
    return list(vocabSet)

# Set-of-words model: 0/1 presence vector
def setOfWords2Vec(vocabList, inputSet):
    retVocabList = [0] * len(vocabList)
    for word in inputSet:
        if word in vocabList:
            retVocabList[vocabList.index(word)] = 1
        else:
            print 'word ', word, 'not in dict'
    return retVocabList

# The other model: bag-of-words, a word-count vector
def bagOfWords2VecMN(vocabList, inputSet):
    returnVec = [0] * len(vocabList)
    for word in inputSet:
        if word in vocabList:
            returnVec[vocabList.index(word)] += 1
    return returnVec

def trainNB0(trainMatrix, trainCategory):
    numTrainDoc = len(trainMatrix)
    numWords = len(trainMatrix[0])
    pAbusive = sum(trainCategory) / float(numTrainDoc)
    # Laplace smoothing: start counts at 1 (denominators at 2) so that
    # no single conditional probability in the product is zero
    p0Num = ones(numWords)
    p1Num = ones(numWords)
    p0Denom = 2.0
    p1Denom = 2.0
    for i in range(numTrainDoc):
        if trainCategory[i] == 1:
            p1Num += trainMatrix[i]
            p1Denom += sum(trainMatrix[i])
        else:
            p0Num += trainMatrix[i]
            p0Denom += sum(trainMatrix[i])
    # Take logs for numerical precision; otherwise the product of many
    # small probabilities underflows to zero
    p1Vect = log(p1Num / p1Denom)
    p0Vect = log(p0Num / p0Denom)
    return p0Vect, p1Vect, pAbusive

def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):
    p1 = sum(vec2Classify * p1Vec) + log(pClass1)       # element-wise mult
    p0 = sum(vec2Classify * p0Vec) + log(1.0 - pClass1)
    if p1 > p0:
        return 1
    else:
        return 0

def spamTest(spamFolder, hamFolder):
    docList = []
    classList = []
    fullText = []
    for i in range(1, 26):
        wordList = textParse(open(spamFolder + str(i) + '.txt').read())
        docList.append(wordList)
        fullText.extend(wordList)
        classList.append(1)
        wordList = textParse(open(hamFolder + str(i) + '.txt').read())
        docList.append(wordList)
        fullText.extend(wordList)
        classList.append(0)
    vocabList = createVocabList(docList)
    trainingSet = range(50)
    testSet = []
    for i in range(10):
        randIndex = int(random.uniform(0, len(trainingSet)))
        testSet.append(trainingSet[randIndex])
        del(trainingSet[randIndex])
    trainMat = []
    trainClasses = []
    for docIndex in trainingSet:
        trainMat.append(setOfWords2Vec(vocabList, docList[docIndex]))
        #trainMat.append(bagOfWords2VecMN(vocabList, docList[docIndex]))
        trainClasses.append(classList[docIndex])
    p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses))
    errorCount = 0
    for docIndex in testSet:    # classify the held-out items
        #wordVector = bagOfWords2VecMN(vocabList, docList[docIndex])
        wordVector = setOfWords2Vec(vocabList, docList[docIndex])
        if classifyNB(array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:
            errorCount += 1
            print "classification error", docList[docIndex]
    print 'the error rate is: ', float(errorCount) / len(testSet)
    #return vocabList, fullText

def main():
    spamTest('email/spam/', 'email/ham/')

if __name__ == '__main__':
    main()
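Point 2 above means a single run's error rate is noisy. A simple way to get a steadier number is to repeat the experiment and average; a minimal sketch, assuming spamTest is modified to return float(errorCount)/len(testSet) instead of only printing it:

# Hypothetical helper: assumes spamTest() has been changed to
# `return float(errorCount)/len(testSet)` rather than just printing it.
def averagedSpamTest(runs=10):
    total = 0.0
    for i in range(runs):
        total += spamTest('email/spam/', 'email/ham/')
    print 'average error rate over', runs, 'runs:', total / runs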
-
Extracting Regional Tendencies from Personal Ads
Here we pull posts from two regional boards of a website and check whether their word usage follows any regional patterns.
Dataset
The data is fetched over RSS using Python's feedparser package (see here to learn more about it). We pull the posts from two regional boards of the same site.
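For a feel of what feedparser returns, a minimal sketch (using one of the feed URLs from the listing below; any RSS feed whose entries carry a summary field works the same way):

# Quick look at the feedparser output that localWords() consumes below
import feedparser

feed = feedparser.parse('http://newyork.craigslist.org/stp/index.rss')
print len(feed['entries'])               # number of posts in the feed
if feed['entries']:
    print feed['entries'][0]['summary']  # the text we tokenize with textParse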
Testing Method
Cross-validation (random hold-out, as above).
Implementation Details
1. Two kinds of words need special handling (in practice they overlap heavily): the highest-frequency words, and the so-called stop words (as I understand them, words that occur very often but carry little real meaning); stop-word lists for various languages can be found here. We remove both so that the result better reflects regional differences (see the sketch after this list).
2. The getTopWords function simply pulls the most characteristic words out of the estimated probabilities; it is not essential for learning Bayes.
3. Apart from the different data source, the implementation is very similar to the spam filter above.
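A sketch of the pruning described in point 1 (pruneVocab is just an illustrative name; it assumes stopwords.txt holds one word per line, and mirrors what calcMostFreq plus stopWords do inside localWords in the full listing below):

# Illustrative sketch: drop the 30 most frequent tokens and any stop words,
# so region-neutral filler does not dominate the vocabulary.
# Assumes stopwords.txt holds one word per line.
def pruneVocab(vocabList, fullText):
    freq = [(tok, fullText.count(tok)) for tok in vocabList]
    freq.sort(key=lambda pair: pair[1], reverse=True)
    for tok, count in freq[:30]:            # most frequent tokens
        if tok in vocabList:
            vocabList.remove(tok)
    for line in open('stopwords.txt').readlines():
        word = line.strip()                 # stop words
        if word in vocabList:
            vocabList.remove(word)
    return vocabList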
-
#coding=utf-8
from numpy import *

# Tokenize a document: split on non-word characters, lowercase,
# and keep only tokens longer than two characters
def textParse(bigString):
    import re
    listOfTokens = re.split(r'\W+', bigString)
    return [tok.lower() for tok in listOfTokens if len(tok) > 2]

# Build a list of all distinct words in the dataset
def createVocabList(dataSet):
    vocabSet = set([])
    for document in dataSet:
        vocabSet = vocabSet | set(document)
    return list(vocabSet)

# Set-of-words model: 0/1 presence vector
def setOfWords2Vec(vocabList, inputSet):
    retVocabList = [0] * len(vocabList)
    for word in inputSet:
        if word in vocabList:
            retVocabList[vocabList.index(word)] = 1
        else:
            print 'word ', word, 'not in dict'
    return retVocabList

# The other model: bag-of-words, a word-count vector
def bagOfWords2VecMN(vocabList, inputSet):
    returnVec = [0] * len(vocabList)
    for word in inputSet:
        if word in vocabList:
            returnVec[vocabList.index(word)] += 1
    return returnVec

def trainNB0(trainMatrix, trainCategory):
    numTrainDoc = len(trainMatrix)
    numWords = len(trainMatrix[0])
    pAbusive = sum(trainCategory) / float(numTrainDoc)
    # Laplace smoothing: start counts at 1 (denominators at 2) so that
    # no single conditional probability in the product is zero
    p0Num = ones(numWords)
    p1Num = ones(numWords)
    p0Denom = 2.0
    p1Denom = 2.0
    for i in range(numTrainDoc):
        if trainCategory[i] == 1:
            p1Num += trainMatrix[i]
            p1Denom += sum(trainMatrix[i])
        else:
            p0Num += trainMatrix[i]
            p0Denom += sum(trainMatrix[i])
    # Take logs for numerical precision; otherwise the product of many
    # small probabilities underflows to zero
    p1Vect = log(p1Num / p1Denom)
    p0Vect = log(p0Num / p0Denom)
    return p0Vect, p1Vect, pAbusive

def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):
    p1 = sum(vec2Classify * p1Vec) + log(pClass1)       # element-wise mult
    p0 = sum(vec2Classify * p0Vec) + log(1.0 - pClass1)
    if p1 > p0:
        return 1
    else:
        return 0

# Load the stop-word list, one word per line
def stopWords():
    stopW = []
    f = open('stopwords.txt').readlines()
    for eachLine in f:
        stopW.append(eachLine[:-1])
    return stopW

# Return the 30 most frequent tokens as (token, count) pairs
def calcMostFreq(vocabList, fullText):
    import operator
    freqDict = {}
    for token in vocabList:
        freqDict[token] = fullText.count(token)
    sortedFreq = sorted(freqDict.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedFreq[:30]

def localWords(rss1, rss0):
    import feedparser
    feed1 = feedparser.parse(rss1)
    feed0 = feedparser.parse(rss0)
    docList = []; classList = []; fullText = []
    minLen = min(len(feed1['entries']), len(feed0['entries']))
    for i in range(minLen):
        wordList = textParse(feed1['entries'][i]['summary'])
        docList.append(wordList)
        fullText.extend(wordList)
        classList.append(1)        # NY is class 1
        wordList = textParse(feed0['entries'][i]['summary'])
        docList.append(wordList)
        fullText.extend(wordList)
        classList.append(0)
    vocabList = createVocabList(docList)            # create vocabulary
    top30Words = calcMostFreq(vocabList, fullText)  # remove top 30 words
    for pairW in top30Words:
        if pairW[0] in vocabList:
            vocabList.remove(pairW[0])
    for word in stopWords():                        # remove stop words
        if word in vocabList:
            vocabList.remove(word)
    trainingSet = range(2 * minLen); testSet = []   # create test set
    for i in range(20):
        randIndex = int(random.uniform(0, len(trainingSet)))
        testSet.append(trainingSet[randIndex])
        del(trainingSet[randIndex])
    trainMat = []; trainClasses = []
    for docIndex in trainingSet:   # train the classifier (get probs) trainNB0
        trainMat.append(bagOfWords2VecMN(vocabList, docList[docIndex]))
        trainClasses.append(classList[docIndex])
    p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses))
    errorCount = 0
    for docIndex in testSet:       # classify the held-out items
        wordVector = bagOfWords2VecMN(vocabList, docList[docIndex])
        if classifyNB(array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:
            errorCount += 1
    print 'the error rate is: ', float(errorCount) / len(testSet)
    return vocabList, p0V, p1V

# Print the most characteristic words for each region
def getTopWords(ny, sf):
    vocabList, p0V, p1V = localWords(ny, sf)
    topNY = []; topSF = []
    for i in range(len(p0V)):
        if p0V[i] > -6.0: topSF.append((vocabList[i], p0V[i]))
        if p1V[i] > -6.0: topNY.append((vocabList[i], p1V[i]))
    sortedSF = sorted(topSF, key=lambda pair: pair[1], reverse=True)
    print "SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**SF**"
    for item in sortedSF:
        print item[0]
    sortedNY = sorted(topNY, key=lambda pair: pair[1], reverse=True)
    print "NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**NY**"
    for item in sortedNY:
        print item[0]

def main():
    #print stopWords()
    localWords('http://newyork.craigslist.org/stp/index.rss', 'http://sfbay.craigslist.org/stp/index.rss')

if __name__ == '__main__':
    main()
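To look at the region-specific vocabulary rather than just the error rate, call getTopWords on the same two feeds (the -6.0 threshold on the log probabilities is the one used in the listing above):

# Print the most characteristic words for each board instead of the error rate
getTopWords('http://newyork.craigslist.org/stp/index.rss',
            'http://sfbay.craigslist.org/stp/index.rss')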
Machine Learning Notes Index