朴素貝葉斯分類器(離散型)算法實現(一)


 

1. 貝葉斯定理:    

   (1)   P(A^B) = P(A|B)P(B) = P(B|A)P(A) 

 由(1)得

   P(A|B) = P(B|A)*P(A)/[p(B)]

 

貝葉斯在最基本題型:

假定一個場景,在一所高中男女比例為4:6, 留長頭發的有男學生有女學生, 我們設定女生都留長發 , 而男生中有10%的留長發,90%留短發.那么如果我們看到遠處一個長發背影?請問是一只男學生的概率?

  分析:

    P(男|長發) = P(長發|男)*P(男)/[p(長發)] 

        = (1/10)*(4/10)/[(6+4*(1/10))/10]

        =1/16 =0.0625

   P(女|長發) =P(長發|女)*P(女)/[p(長發)]

                  =1*(6/10)/[(6+4*(1/10))/10]

                 =30/32 =15/16

 

再舉一個列子:

某個醫院早上收了六個門診病人,如下表。

  症狀  職業   疾病

  打噴嚏 護士   感冒 
  打噴嚏 農夫   過敏 
  頭痛  建築工人 腦震盪 
  頭痛  建築工人 感冒 
  打噴嚏 教師   感冒 
  頭痛  教師   腦震盪

現在又來了第七個病人,是一個打噴嚏的建築工人。請問他患上感冒的概率有多大?(來源: http://www.ruanyifeng.com/blog/2013/12/naive_bayes_classifier.html)

Java代碼實現:

 1 /**
 2  * *********************************************************
 3  * <p/>
 4  * Author:     XiJun.Gong
 5  * Date:       2016-08-31 20:36
 6  * Version:    default 1.0.0
 7  * Class description:
 8  * <p>特征庫</p>
 9  * <p/>
10  * *********************************************************
11  */
12 
13 public class FeaturePoint {
14 
15     private String key;
16     private double p;
17 
18     public FeaturePoint(String key) {
19         this(key, 1);
20     }
21 
22     public FeaturePoint(String key, double p) {
23         this.key = key;
24         this.p = p;
25     }
26 
27     public String getKey() {
28         return key;
29     }
30 
31     public void setKey(String key) {
32         this.key = key;
33     }
34 
35     public double getP() {
36         return p;
37     }
38 
39     public void setP(double p) {
40         this.p = p;
41     }
42 }
 1 import com.google.common.collect.ArrayListMultimap;
 2 import com.google.common.collect.Multimap;
 3 
 4 import java.util.Collection;
 5 import java.util.List;
 6 
 7 /**
 8  * *********************************************************
 9  * <p/>
10  * Author:     XiJun.Gong
11  * Date:       2016-08-31 15:48
12  * Version:    default 1.0.0
13  * Class description:
14  * <p/>
15  * *********************************************************
16  */
17 
18 public class Bayes {
19     private static Multimap<String, FeaturePoint> map = ArrayListMultimap.create();
20 
21     /*喂數據*/
22     public void input(List<String> labels) {
23 
24         for (String key : labels) {
25             Collection<FeaturePoint> features = map.get(key);
26             for (String value : labels) {
27                 if (features == null || features.size() < 1) {
28                     map.put(key, new FeaturePoint(value));
29                     continue;
30                 }
31                 boolean tag = false;
32                 for (FeaturePoint feature : features) {
33                     if (feature.getKey().equals(value)) {
34                         Double num = feature.getP() + 1;
35                         map.remove(key, feature);
36                         map.put(key, new FeaturePoint(value, num));
37                         tag = true;
38                         break;
39                     }
40                 }
41                 if (!tag)
42                     map.put(key, new FeaturePoint(value));
43             }
44         }
45     }
46 
47     /*構造模型*/
48     public void excute(List<String> labels) {
49         //   excute(labels, null);
50     }
51 
52     /*構造模型*/
53     public Double excute(final List<String> labels, final String judge, Integer dataSize) {
54 
55         Double denominator = 1d;    //分母
56         Double numerator = 1d;      //分子
57         Double coughNum = 0d;
58        /*選擇相關性分子*/
59         Collection<FeaturePoint> featurePoints = map.get(judge);
60         for (FeaturePoint featurePoint : featurePoints) {
61             if (judge.equals(featurePoint.getKey())) {
62                 coughNum = featurePoint.getP();
63                 denominator *= (featurePoint.getP() / dataSize);
64                 break;
65             }
66         }
67 
68         Integer size = featurePoints.size() - 1; //容量
69         for (String label : labels) {
70             for (FeaturePoint featurePoint : featurePoints) {
71                 if (label.equals(featurePoint.getKey())) {
72                     denominator *= (featurePoint.getP() / coughNum);
73                     for (FeaturePoint feature : map.get(label)) {
74                         if (label.equals(feature.getKey())) {
75                             numerator *= (feature.getP() / dataSize);
76                         }
77                     }
78                 }
79             }
80         }
81 
82         return denominator / numerator;
83     }
84 
85 }

 

 1 import com.google.common.collect.Lists;
 2 
 3 import java.util.List;
 4 import java.util.Scanner;
 5 
 6 /**
 7  * *********************************************************
 8  * <p/>
 9  * Author:     XiJun.Gong
10  * Date:       2016-09-01 14:58
11  * Version:    default 1.0.0
12  * Class description:
13  * <p/>
14  * *********************************************************
15  */
16 public class Main {
17 
18     public static void main(String args[]) {
19 
20         Scanner scanner = new Scanner(System.in);
21         Integer size = scanner.nextInt();
22         Integer row = scanner.nextInt();
23         Bayes bayes = new Bayes();
24         while (scanner.hasNext()) {
25 
26             for (int ro = 0; ro < row; ro++) {
27                 List<String> list = Lists.newArrayList();
28                 for (int i = 0; i < size; i++) {
29                     list.add(scanner.next());
30                 }
31                 bayes.input(list);
32             }
33             List<String> list = Lists.newArrayList();
34             for (int i = 0; i < size - 1; i++) {
35                 list.add(scanner.next());
36             }
37             String judge = scanner.next();
38             System.out.println(bayes.excute(list, judge,row));
39             ;
40         }
41 
42     }
43 }

pom.xml包

    <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>18.0</version>
        </dependency>

結果:

1 3 6
2 打噴嚏 護士   感冒 
3   打噴嚏 農夫   過敏 
4   頭痛  建築工人 腦震盪 
5   頭痛  建築工人 感冒 
6   打噴嚏 教師   感冒 
7   頭痛  教師   腦震盪
8 打噴嚏  建築工人 感冒
9 0.6666666666666666 
1 3 6
2   打噴嚏 護士   感冒 
3   打噴嚏 農夫   過敏 
4   頭痛  建築工人 腦震盪 
5   頭痛  建築工人 感冒 
6   打噴嚏 教師   感冒 
7   頭痛  教師   腦震盪
8 打噴嚏 護士   感冒 
9 1.3333333333333333

 

 1 2 50
 2 男  長發
 3 男  短發
 4 男  短發
 5 男  短發
 6 男  短發
 7 男  短發
 8 男  短發
 9 男  短發
10 男  短發
11 男  短發
12 男  短發
13 男  短發
14 男  短發
15 男  短發
16 男  短發
17 男  短發
18 男  短發
19 男  短發
20 男  短發
21 男  長發
22 女  長發
23 女  長發
24 女  長發
25 女  長發
26 女  長發
27 女  長發
28 女  長發
29 女  長發
30 女  長發
31 女  長發
32 女  長發
33 女  長發
34 女  長發
35 女  長發
36 女  長發
37 女  長發
38 女  長發
39 女  長發
40 女  長發
41 女  長發
42 女  長發
43 女  長發
44 女  長發
45 女  長發
46 女  長發
47 女  長發
48 女  長發
49 女  長發
50 女  長發
51 女  長發
52             
53 長發 男
54 0.06250000000000001
 1 2 50
 2 男  長發
 3 男  短發
 4 男  短發
 5 男  短發
 6 男  短發
 7 男  短發
 8 男  短發
 9 男  短發
10 男  短發
11 男  短發
12 男  短發
13 男  短發
14 男  短發
15 男  短發
16 男  短發
17 男  短發
18 男  短發
19 男  短發
20 男  短發
21 男  長發
22 女  長發
23 女  長發
24 女  長發
25 女  長發
26 女  長發
27 女  長發
28 女  長發
29 女  長發
30 女  長發
31 女  長發
32 女  長發
33 女  長發
34 女  長發
35 女  長發
36 女  長發
37 女  長發
38 女  長發
39 女  長發
40 女  長發
41 女  長發
42 女  長發
43 女  長發
44 女  長發
45 女  長發
46 女  長發
47 女  長發
48 女  長發
49 女  長發
50 女  長發
51 女  長發
52 長發 女
53 0.9375
View Code

 利用貝葉斯進行分類?

  1 import com.google.common.collect.ArrayListMultimap;
  2 import com.google.common.collect.Lists;
  3 import com.google.common.collect.Multimap;
  4 
  5 import java.util.Collection;
  6 import java.util.List;
  7 
  8 /**
  9  * *********************************************************
 10  * <p/>
 11  * Author:     XiJun.Gong
 12  * Date:       2016-08-31 15:48
 13  * Version:    default 1.0.0
 14  * Class description:
 15  * <p/>
 16  * *********************************************************
 17  */
 18 
 19 public class Bayes {
 20     private Multimap<String, FeaturePoint> map = null;
 21     private List<String> featurePool = null;
 22 
 23     public Bayes() {
 24         map = ArrayListMultimap.create();
 25         featurePool = Lists.newArrayList();
 26     }
 27 
 28     public void add(String label) {
 29         featurePool.add(label);
 30     }
 31 
 32     /*喂數據*/
 33     public void input(List<String> labels) {
 34 
 35         for (String key : labels) {
 36             Collection<FeaturePoint> features = map.get(key);
 37             for (String value : labels) {
 38                 if (features == null || features.size() < 1) {
 39                     map.put(key, new FeaturePoint(value));
 40                     continue;
 41                 }
 42                 boolean tag = false;
 43                 for (FeaturePoint feature : features) {
 44                     if (feature.getKey().equals(value)) {
 45                         Double num = feature.getP() + 1;
 46                         map.remove(key, feature);
 47                         map.put(key, new FeaturePoint(value, num));
 48                         tag = true;
 49                         break;
 50                     }
 51                 }
 52                 if (!tag)
 53                     map.put(key, new FeaturePoint(value));
 54             }
 55         }
 56     }
 57 
 58     /*最符合那個分類*/
 59     public String excute(List<String> labels, Integer dataSize) {
 60 
 61         Double max = -999999999d;
 62         String max_obj = null;
 63         List<Double> ans = Lists.newArrayList();
 64         for (String label : featurePool) {
 65             Double p = excute(labels, label, dataSize);
 66             ans.add(p);
 67             if (max < p) {
 68                 max_obj = label;
 69                 max = p;
 70             }
 71         }
 72         return max_obj;
 73     }
 74 
 75     /*構造模型*/
 76     public Double excute(final List<String> labels, final String judge, Integer dataSize) {
 77 
 78         Double denominator = 1d;    //分母
 79         Double numerator = 1d;      //分子
 80         Double coughNum = 0d;
 81        /*選擇相關性分子*/
 82         Collection<FeaturePoint> featurePoints = map.get(judge);
 83         for (FeaturePoint featurePoint : featurePoints) {
 84             if (judge.equals(featurePoint.getKey())) {
 85                 coughNum = featurePoint.getP();
 86                 denominator *= (featurePoint.getP() / dataSize);
 87                 break;
 88             }
 89         }
 90        /*O(n^3)*/
 91         Integer size = featurePoints.size() - 1; //容量
 92         for (String label : labels) {
 93             for (FeaturePoint featurePoint : featurePoints) {
 94                 if (label.equals(featurePoint.getKey())) {
 95                     denominator *= (featurePoint.getP() / coughNum);
 96                     for (FeaturePoint feature : map.get(label)) {
 97                         if (label.equals(feature.getKey())) {
 98                             numerator *= (feature.getP() / dataSize);
 99                         }
100                     }
101                 }
102             }
103         }
104 
105         return denominator / numerator;
106     }
107 
108 }
View Code
 1 import com.google.common.collect.Lists;
 2 
 3 import java.util.List;
 4 import java.util.Scanner;
 5 
 6 /**
 7  * *********************************************************
 8  * <p/>
 9  * Author:     XiJun.Gong
10  * Date:       2016-09-01 14:58
11  * Version:    default 1.0.0
12  * Class description:
13  * <p/>
14  * *********************************************************
15  */
16 public class Main {
17 
18     public static void main(String args[]) {
19 
20         Scanner scanner = new Scanner(System.in);
21         Integer size = scanner.nextInt();
22         Integer row = scanner.nextInt();
23         Integer category = scanner.nextInt();
24         while (scanner.hasNext()) {
25             Bayes bayes = new Bayes();
26             for (int ro = 0; ro < row; ro++) {
27                 List<String> list = Lists.newArrayList();
28                 for (int i = 0; i < size; i++) {
29                     list.add(scanner.next());
30                 }
31                 bayes.input(list);
32             }
33             List<String> list = Lists.newArrayList();
34             for (int i = 0; i < size - 1; i++) {
35                 list.add(scanner.next());
36             }
37             for (int i = 0; i < category; i++) {
38                 bayes.add(scanner.next());
39             }
40             System.out.println(bayes.excute(list, row));
41         }
42 
43     }
44 }
View Code

結果:

 1 2 50 2
 2 男  長發
 3 男  短發
 4 男  短發
 5 男  短發
 6 男  短發
 7 男  短發
 8 男  短發
 9 男  短發
10 男  短發
11 男  短發
12 男  短發
13 男  短發
14 男  短發
15 男  短發
16 男  短發
17 男  短發
18 男  短發
19 男  短發
20 男  短發
21 男  長發
22 女  長發
23 女  長發
24 女  長發
25 女  長發
26 女  長發
27 女  長發
28 女  長發
29 女  長發
30 女  長發
31 女  長發
32 女  長發
33 女  長發
34 女  長發
35 女  長發
36 女  長發
37 女  長發
38 女  長發
39 女  長發
40 女  長發
41 女  長發
42 女  長發
43 女  長發
44 女  長發
45 女  長發
46 女  長發
47 女  長發
48 女  長發
49 女  長發
50 女  長發
51 女  長發
52 長發
53 男 女
54
View Code
 1 2 50 2
 2 男  長發
 3 男  短發
 4 男  短發
 5 男  短發
 6 男  短發
 7 男  短發
 8 男  短發
 9 男  短發
10 男  短發
11 男  短發
12 男  短發
13 男  短發
14 男  短發
15 男  短發
16 男  短發
17 男  短發
18 男  短發
19 男  短發
20 男  短發
21 男  長發
22 女  長發
23 女  長發
24 女  長發
25 女  長發
26 女  長發
27 女  長發
28 女  長發
29 女  長發
30 女  長發
31 女  長發
32 女  長發
33 女  長發
34 女  長發
35 女  長發
36 女  長發
37 女  長發
38 女  長發
39 女  長發
40 女  長發
41 女  長發
42 女  長發
43 女  長發
44 女  長發
45 女  長發
46 女  長發
47 女  長發
48 女  長發
49 女  長發
50 女  長發
51 女  長發
52 短發
53 男 女
54
View Code

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM