python實現一個朴素貝葉斯分類方法


1.公式

上式中左邊D是需要預測的測試數據屬性,h是需要預測的類;右邊式子分子是屬性的條件概率和類別的先驗概率,可以從統計訓練數據中得到,分母對於所有實例都一樣,可以不考慮,所有只需 ,返回最大概率的那個類別。但是如果測試數據中沒有那個屬性,整個預測概率會是0;此外,此式針對離散型屬性進行訓練,針對連續的數值型屬性可以考慮分段,也可以假設其滿足某種分布,比如正態分布,利用概率密度函數求概率。

2.部分改進

(1).針對測試數據中沒有那個屬性,可以平滑一下,比如下(針對非數值型屬性):

上式中n是某個類別下的實例數,nc是此類別下的屬性個數,m是此屬性的取值個數,p是此屬性取值出現的概率。比如一個屬性:性別,取值男或女,則 m=2,p=1/2。

(2).針對連續的數值型屬性,可以分段比如年齡0-10為A,10-30為B等;還可以假設它服從高斯分布(正態分布),利分布函數計算概率:

其中uij是某列數值型屬性的均值,Qij是某列數值型屬性樣本標准差,Xi是數值屬性。訓練的時候只需要統計均值,樣本標准差就行了,預測的時候利用。

3.python實現

  1 #!/usr/bin/python
  2 # -*- coding: utf-8 -*-
  3 
  4 import codecs
  5 import math
  6 
  7 class BayesClassifier:
  8 
  9     def __init__(self,dataFormat):
 10         self.prior = {}#類別的先驗概率
 11         self.conditional = {}#屬性的條件概率
 12         # 輸入的數據格式,attr表示非數值型屬性,num表示數值型屬性,class表示類別
 13         self.format=dataFormat.strip().split('\t')
 14 
 15     #讀取數據
 16     def readData(self,dataFile):
 17         total = 0#所有實例數
 18         self.classes = {}#統計類別
 19         self.counts = {}#用來統計
 20         totals={}#統計數值型每列的和
 21         numericValues={}#數值型每列值
 22 
 23         with codecs.open(dataFile,'r','utf-8') as f:
 24             for line in f:
 25                 fields=line.strip().split('\t')
 26                 fieldSize=len(fields)
 27                 vector=[]
 28                 nums=[]
 29                 for i in range(fieldSize):
 30                     if self.format[i]=='num':
 31                         nums.append(float(fields[i]))
 32                     elif self.format[i]=='attr':
 33                         vector.append(fields[i])
 34                     elif self.format[i]=='class':
 35                         category=fields[i]
 36                 total+=1
 37                 self.classes.setdefault(category,0)
 38                 self.counts.setdefault(category,{})
 39                 totals.setdefault(category,{})
 40                 numericValues.setdefault(category,{})
 41                 self.classes[category]+=1
 42                 #統計一條非數值型實例的屬性
 43                 col=0
 44                 for columnValue in vector:
 45                     col+=1
 46                     self.counts[category].setdefault(col,{})
 47                     self.counts[category][col].setdefault(columnValue,0)
 48                     self.counts[category][col][columnValue]+=1
 49                 col=0
 50                 for columnValue in nums:
 51                     col+=1
 52                     totals[category].setdefault(col,0)
 53                     totals[category][col]+=columnValue
 54                     numericValues[category].setdefault(col,[])
 55                     numericValues[category][col].append(columnValue)
 56 
 57         #以上統計完成,計算類別先驗概率和屬性條件概率
 58         #計算類的先驗概率=此類的實例數/總的實例數
 59         for category,count in self.classes.items():
 60             self.prior[category]=count/total
 61         #計算屬性的條件概率=此類中屬性數/此類實例數
 62         for category,columns in self.counts.items():
 63             self.conditional.setdefault(category,{})
 64             for col,valueCounts in columns.items():
 65                 self.conditional[category].setdefault(col,{})
 66                 colSize=len(valueCounts)#這一列屬性的取值個數(如性別取值為男和女,則colSize=2)
 67                 for attr,count in valueCounts.items():
 68                     #平滑一下
 69                     self.conditional[category][col][attr]=(count+colSize*1/colSize)/(self.classes[category]+colSize)
 70         #在數值型列中計算均值和樣本標准差
 71         #每列的均值
 72         self.means={}
 73         self.totals=totals
 74         for category,columns in totals.items():
 75             self.means.setdefault(category,{})
 76             for col,colSum  in columns.items():
 77                 self.means[category][col]=colSum/self.classes[category]
 78         #每列的標准差
 79         self.std={}
 80         for category,columns in numericValues.items():
 81             self.std.setdefault(category,{})
 82             for col,values in columns.items():
 83                 ssd=0
 84                 mean=self.means[category][col]
 85                 for value in values:
 86                     ssd+=(value-mean)**2
 87                 self.std[category][col]=math.sqrt(ssd/(self.classes[category]-1))
 88 
 89 
 90     #分類,返回分類結果
 91     def classify(self,itemVector):
 92         results=[]
 93         for category,prior in self.prior.items():
 94             prob=prior
 95             col=1
 96             for attrValue in itemVector:
 97                 if self.format[col]=='attr':
 98                     # 如果預測數據沒有這個屬性,則平滑一下,不是返回0(返回0會導致整個預測結果為0)
 99                     if not attrValue in self.conditional[category][col]:
100                         colSize=len(self.counts[category][col])
101                         prob=prob*(0+colSize*1/colSize)/(self.classes[category]+colSize)
102                     else:
103                         prob=prob*self.conditional[category][col][attrValue]
104                 #針對數值型,我們先得到該列均值與樣本標准差,利用正態分布得到概率(假設該列數值滿足正態分布)
105                 elif self.format[col]=='num':
106                     mean=self.means[category][col]
107                     std=self.std[category][col]
108                     prob=prob*self.gaussian(mean,std,attrValue)
109                 col+=1
110             results.append((prob,category))
111         return max(results)[1]
112 
113     #高斯分布
114     def gaussian(self,mean,std,x):
115         sqrt2pi = math.sqrt(2 * math.pi)
116         ePart=math.pow(math.e,-(x-mean)**2/(2*std**2))
117         prob=(1.0/sqrt2pi*std)*ePart
118         return prob
119 
120     # 十折驗證讀取數據,prefix為文件名前綴,i作為測試集編號
121     def tenFoldReadData(self,prefix,testNumber):
122         total = 0  # 所有實例數
123         self.classes = {}  # 統計類別
124         self.counts = {}  # 用來統計
125         totals = {}  # 統計數值型每列的和
126         numericValues = {}  # 數值型每列值
127 
128         for i in range(1,11):
129             if i!=testNumber:
130                 filename='%s-%02s' % (prefix,i)
131                 with codecs.open(filename, 'r', 'utf-8') as f:
132                     for line in f:
133                         fields = line.strip().split('\t')
134                         fieldSize = len(fields)
135                         vector = []
136                         nums = []
137                         for i in range(fieldSize):
138                             if self.format[i] == 'num':
139                                 nums.append(float(fields[i]))
140                             elif self.format[i] == 'attr':
141                                 vector.append(fields[i])
142                             elif self.format[i] == 'class':
143                                 category = fields[i]
144                         total += 1
145                         self.classes.setdefault(category, 0)
146                         self.counts.setdefault(category, {})
147                         totals.setdefault(category, {})
148                         numericValues.setdefault(category, {})
149                         self.classes[category] += 1
150                         # 統計一條非數值型實例的屬性
151                         col = 0
152                         for columnValue in vector:
153                             col += 1
154                             self.counts[category].setdefault(col, {})
155                             self.counts[category][col].setdefault(columnValue, 0)
156                             self.counts[category][col][columnValue] += 1
157                         col = 0
158                         for columnValue in nums:
159                             col += 1
160                             totals[category].setdefault(col, 0)
161                             totals[category][col] += columnValue
162                             numericValues[category].setdefault(col, [])
163                             numericValues[category][col].append(columnValue)
164 
165         # 以上統計完成,計算類別先驗概率和屬性條件概率
166         # 計算類的先驗概率=此類的實例數/總的實例數
167         for category, count in self.classes.items():
168             self.prior[category] = count / total
169         # 計算屬性的條件概率=此類中屬性數/此類實例數
170         for category, columns in self.counts.items():
171             self.conditional.setdefault(category, {})
172             for col, valueCounts in columns.items():
173                 self.conditional[category].setdefault(col, {})
174                 colSize = len(valueCounts)  # 這一列屬性的取值個數(如性別取值為男和女,則colSize=2)
175                 for attr, count in valueCounts.items():
176                     # 平滑一下
177                     self.conditional[category][col][attr] = (count + colSize * 1 / colSize) / (
178                     self.classes[category] + colSize)
179         # 在數值型列中計算均值和樣本標准差
180         # 每列的均值
181         self.means = {}
182         self.totals = totals
183         for category, columns in totals.items():
184             self.means.setdefault(category, {})
185             for col, colSum in columns.items():
186                 self.means[category][col] = colSum / self.classes[category]
187         # 每列的標准差
188         self.std = {}
189         for category, columns in numericValues.items():
190             self.std.setdefault(category, {})
191             for col, values in columns.items():
192                 ssd = 0
193                 mean = self.means[category][col]
194                 for value in values:
195                     ssd += (value - mean) ** 2
196                 self.std[category][col] = math.sqrt(ssd / (self.classes[category] - 1))
197 
198     #利用十折交叉驗證,測試一個桶中的數據,prefix為統計文件名前綴,testNumber為要測試的一個桶中的數據
199     def testOneBucket(self,prefix,testNumber):
200         filename='%s-%02i' % (prefix,testNumber)
201         totals={}
202         with codecs.open(filename,'r','utf-8') as f:
203             for line in f:
204                 data=line.strip().split('\t')
205                 itemVector=[]
206                 classInColumn=-1
207                 for i in range(len(self.format)):
208                     if self.format[i]=='num':
209                         itemVector.append(float(data[i]))
210                     elif self.format[i]=='attr':
211                         itemVector.append(data[i])
212                     elif self.format[i]=='class':
213                         classInColumn=i
214                 realClass=data[classInColumn]#真實的類
215                 classifiedClass=self.classify(itemVector)#預測的類
216                 totals.setdefault(realClass,{})
217                 totals[realClass].setdefault(classifiedClass,0)
218                 totals[realClass][classifiedClass]+=1
219         return totals
220 
221 #十折交叉驗證,prefix為十個文件名字的前綴,dataForamt為數據格式
222 def tenfold(prefix,dataFormat):
223     results={}
224     for i in range(1,11):
225         classify=BayesClassifier(dataFormat)
226         classify.tenFoldReadData(prefix,i)
227         totals=classify.testOneBucket(prefix,i)
228         for key,value in totals.items():
229             results.setdefault(key,{})
230             for ckey,cvalue in value.items():
231                 results[key].setdefault(ckey,0)
232                 results[key][ckey]+=cvalue
233     #結果展示
234     classes=list(results.keys())
235     classes.sort()
236     print(      '\n                 classes as: ')
237     header='                '
238     subheader='               +'
239     for cls in classes:
240         header+='%  10s '% cls
241         subheader+='--------+'
242     print(header)
243     print(subheader)
244     total=0.0
245     correct=0.0
246     for cls in classes:
247         row=' %10s   |' % cls
248         for c2 in classes:
249             if c2 in results[cls]:
250                 count=results[cls][c2]
251             else:
252                 count=0
253             row+=' %5i |' % count
254             total+=count
255             if c2==cls:
256                 correct+=count
257         print(row)
258     print(subheader)
259     print('\n%5.3f 正確率' % ((correct*100/total)))
260     print('總共 %i 實例'% total)
261 
262 if __name__=='__main__':
263     #classify=BayesClassifier('num,num,num,num,num,num,num,num,class')
264     #classify.readData('dataFile')
265     #print(classify.classify([2,120,54,0,0,26.8,0.455,27]))
266     tenfold('dataFilePrefix','num,num,num,num,num,num,num,num,class')#十折交叉驗證

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM