參考文章:https://blog.csdn.net/qq_32690999/article/details/78737393
項目代碼目錄結構

模擬訓練的數據集

核心代碼
Bayes.java
package IsStudent_bys; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; public class Bayes { //按類別分類 //輸入:訓練數據(dataSet) //輸出:類別到訓練數據的一個Map public Map<String,ArrayList<ArrayList<String>>> classify(ArrayList<ArrayList<String>> dataSet){ Map<String,ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>(); //待返回的Map int num=dataSet.size(); for(int i=0;i<num;i++) //遍歷所有數據項 { ArrayList<String> Y = dataSet.get(i); //將第i個訓練樣本的信息取出 String Class = Y.get(Y.size()-1).toString(); //約定將類別信息放在最后一個字符串 if(map.containsKey(Class)){ //判斷map中是否已經有這個類了 map.get(Class).add(Y); }else{ //若沒有這個類就新建一個可變長數組記錄並加入map ArrayList<ArrayList<String>> nlist = new ArrayList<ArrayList<String>>(); nlist.add(Y); map.put(Class,nlist); } } return map; } //計算分類后每個類對應的樣本中某個特征出現的概率 //輸入:某一類別對應的數據(classdata) 目標值(value) 相應的列值(index) //輸出:該類數據中相應列上的值等於目標值得頻率 public double CalPro_yj_c(ArrayList<ArrayList<String>> classdata, String value, int index){ int sum = 0; //sum用於記錄相同特征出現的頻數 int num = classdata.size(); for(int i=0;i<num;i++) { ArrayList<String> Y = classdata.get(i); if(Y.get(index).equals(value)) sum++; //相同則計數 } return (double)sum/num; //返回頻率,以頻率帶概率 } //貝葉斯分類器主函數 //輸入:訓練集(可變長數組);待分類集 //輸出:概率最大的類別 public String bys_Main(ArrayList<ArrayList<String>> dataSet, ArrayList<String> testSet){ Map<String, ArrayList<ArrayList<String>>> doc = this.classify(dataSet); //用本class中的分類函數構造映射 Object classes[] = doc.keySet().toArray(); //把map中所有的key取出來(即所有類別) ,借鑒學習了object的使用(待深入了解) double Max_Value=0.0; //最大的概率 int Max_Class=-1; //用於記錄最大類的編號 for(int i=0;i<doc.size();i++) //對每一個類分別計算,本程序只有兩個類 { String c = classes[i].toString(); //將類提取出 ArrayList<ArrayList<String>> y = doc.get(c); //提取該類對應的數據列表 double prob = (double)y.size()/dataSet.size(); //計算比例 System.out.println(c+" : "+prob); //輸出該類的樣本占總樣本個數的比例! for(int j=0;j<testSet.size();j++) //對每個屬性計算先驗概率 { double P_yj_c = CalPro_yj_c(y,testSet.get(j),j); //輸出中間結果以便測試System.out.println("now in bys_Main!!"+P_yj_c); prob = prob*P_yj_c; } System.out.printf("P(%s | testcase) * P(testcase) = %f\n",c,prob); //輸出分子的概率大小 if(prob>Max_Value) //更新分子最大概率 { Max_Value=prob; Max_Class=i; } } return classes[Max_Class].toString(); } }
FetchData.java
package IsStudent_bys; import java.io.IOException; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import java.util.StringTokenizer; public class FetchData { //連接數據庫,讀取訓練數據 //輸入:數據庫 //輸出:可變長數組 public ArrayList<ArrayList<String>> fetch_traindata(){ ArrayList<ArrayList<String>> dataSet = new ArrayList<ArrayList<String>>(); //待返回 Connection conn; String driver = "com.mysql.jdbc.Driver"; String url = "jdbc:mysql://localhost:3306/Bayes"; //指向要訪問的數據庫!注意后面跟的是數據庫名稱 String user = "root"; //navicat for sql配置的用戶名 String password = "root"; //navicat for sql配置的密碼 try{ Class.forName(driver); //用class加載動態鏈接庫——驅動程序 conn = DriverManager.getConnection(url,user,password); //利用信息鏈接數據庫 if(!conn.isClosed()) System.out.println("Succeeded connecting to the Database!"); Statement statement = conn.createStatement(); //用statement 來執行sql語句 String sql = "select * from TrainData"; //這是sql語句中的查詢某個表,注意后面的emp是表名!!! ResultSet rs = statement.executeQuery(sql); //用於返回結果 String str = null; while(rs.next()){ //一直讀到最后一條表 ArrayList<String> s= new ArrayList<String>(); str = rs.getString("Sex"); //分別讀取相應欄位的信息加入到可變長數組中 s.add(str); str = rs.getString("tatto"); s.add(str); str = rs.getString("smoking"); s.add(str); str = rs.getString("wearglasses"); s.add(str); str = rs.getString("ridebike"); s.add(str); str = rs.getString("isStudent"); s.add(str); dataSet.add(s); //加入dataSet //System.out.println(s); 輸出中間結果調試 } rs.close(); conn.close(); }catch(ClassNotFoundException e){ //catch不同的錯誤信息,並報錯 System.out.println("Sorry,can`t find the Driver!"); e.printStackTrace(); }catch(SQLException e){ e.printStackTrace(); }catch (Exception e) { e.printStackTrace(); }finally{ System.out.println("數據庫訓練數據讀取成功!"); } return dataSet; } public ArrayList<String> read_testdata(String str) throws IOException //將用戶輸入的一整行字符串分割解析成可變長數組 { ArrayList<String> testdata=new ArrayList<String>(); //待返回 StringTokenizer tokenizer = new StringTokenizer(str); while (tokenizer.hasMoreTokens()) { testdata.add(tokenizer.nextToken()); } return testdata; } }
Main.java
package IsStudent_bys; import java.io.BufferedInputStream; import java.io.IOException; import java.util.ArrayList; import java.util.Scanner; public class Main { //主函數,讀取數據庫,並讀入待判定數據,輸出結果 public static void main(String[] args) { FetchData Fdata = new FetchData(); //java對函數的調用要先聲明相應的對象再調用 Bayes bys = new Bayes(); ArrayList<ArrayList<String>> dataSet = null; //訓練數據列表 ArrayList<String> testSet = null; //測試數據 try{ System.out.println("從數據庫讀入訓練數據:"); dataSet = Fdata.fetch_traindata(); //讀取訓練數據集合 System.out.println("請輸入測試數據:"); Scanner cin = new Scanner(new BufferedInputStream(System.in)); //從標准輸入輸出中讀取測試數據 while(cin.hasNext()) //支持多條測試數據讀取 { String str = cin.nextLine(); //先讀入一行 testSet = Fdata.read_testdata(str);//將這一行進行字符串分隔解析后返回可變長數組類型 //System.out.println(testSet); //輸出中間結果 String ans = bys.bys_Main(dataSet, testSet); //調用貝葉斯分類器 if(ans.equals("yes")) System.out.println("Yes!!! 根據已有數據推斷極有可能像是一個學生!"); //輸出結果 else System.out.println("他/她 的特征不像一名學生!"); } cin.close(); }catch (IOException e) { //處理異常 e.printStackTrace(); } } }
運行效果截圖:

