基於貝葉斯算法實現簡單的分類(java)


參考文章:https://blog.csdn.net/qq_32690999/article/details/78737393

項目代碼目錄結構

 

模擬訓練的數據集

 

 核心代碼

Bayes.java

package IsStudent_bys;
 
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
 
public class Bayes {
 
    //按類別分類
    //輸入:訓練數據(dataSet)
    //輸出:類別到訓練數據的一個Map
    public Map<String,ArrayList<ArrayList<String>>> classify(ArrayList<ArrayList<String>> dataSet){
        Map<String,ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>(); //待返回的Map
        int num=dataSet.size();
        for(int i=0;i<num;i++)  //遍歷所有數據項
        {
            ArrayList<String> Y = dataSet.get(i);  //將第i個訓練樣本的信息取出
            String Class = Y.get(Y.size()-1).toString();  //約定將類別信息放在最后一個字符串
            
            if(map.containsKey(Class)){  //判斷map中是否已經有這個類了
                map.get(Class).add(Y);
            }else{  //若沒有這個類就新建一個可變長數組記錄並加入map
                ArrayList<ArrayList<String>> nlist = new ArrayList<ArrayList<String>>();
                nlist.add(Y);
                map.put(Class,nlist);
            }
        }
        return map;
    }
    
    //計算分類后每個類對應的樣本中某個特征出現的概率
    //輸入:某一類別對應的數據(classdata) 目標值(value) 相應的列值(index)
    //輸出:該類數據中相應列上的值等於目標值得頻率
    public double CalPro_yj_c(ArrayList<ArrayList<String>> classdata, String value, int index){
        int sum = 0;  //sum用於記錄相同特征出現的頻數
        int num = classdata.size();
        for(int i=0;i<num;i++)
        {
            ArrayList<String> Y = classdata.get(i);
            if(Y.get(index).equals(value)) sum++;  //相同則計數
        }
        return (double)sum/num;  //返回頻率,以頻率帶概率
        
    }
    
    //貝葉斯分類器主函數
    //輸入:訓練集(可變長數組);待分類集
    //輸出:概率最大的類別
    public String bys_Main(ArrayList<ArrayList<String>> dataSet, ArrayList<String> testSet){
        Map<String, ArrayList<ArrayList<String>>> doc = this.classify(dataSet);  //用本class中的分類函數構造映射
        
        Object classes[] = doc.keySet().toArray(); //把map中所有的key取出來(即所有類別) ,借鑒學習了object的使用(待深入了解)
        double Max_Value=0.0; //最大的概率
        int Max_Class=-1;     //用於記錄最大類的編號
        for(int i=0;i<doc.size();i++)  //對每一個類分別計算,本程序只有兩個類
        {
            String c = classes[i].toString();  //將類提取出
            ArrayList<ArrayList<String>> y = doc.get(c);  //提取該類對應的數據列表
            double prob = (double)y.size()/dataSet.size();  //計算比例
            
            System.out.println(c+" : "+prob);  //輸出該類的樣本占總樣本個數的比例!
            
            for(int j=0;j<testSet.size();j++)  //對每個屬性計算先驗概率
            {
                double P_yj_c = CalPro_yj_c(y,testSet.get(j),j);
                //輸出中間結果以便測試System.out.println("now in bys_Main!!"+P_yj_c);
                prob = prob*P_yj_c;
            }
            
            System.out.printf("P(%s | testcase) * P(testcase) = %f\n",c,prob);  //輸出分子的概率大小
            if(prob>Max_Value)  //更新分子最大概率
            {
                Max_Value=prob;
                Max_Class=i;
            }
        }
        return classes[Max_Class].toString();
    }
}

FetchData.java

package IsStudent_bys;
 
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.StringTokenizer;
 
public class FetchData {
 
    //連接數據庫,讀取訓練數據
    //輸入:數據庫
    //輸出:可變長數組
    public ArrayList<ArrayList<String>> fetch_traindata(){
        ArrayList<ArrayList<String>> dataSet = new ArrayList<ArrayList<String>>();  //待返回
        
        Connection conn;    
        String driver = "com.mysql.jdbc.Driver"; 
        String url = "jdbc:mysql://localhost:3306/Bayes";  //指向要訪問的數據庫!注意后面跟的是數據庫名稱
        String user = "root";   //navicat for sql配置的用戶名
        String password = "root";  //navicat for sql配置的密碼
        try{
            Class.forName(driver);  //用class加載動態鏈接庫——驅動程序
            conn = DriverManager.getConnection(url,user,password);  //利用信息鏈接數據庫
            if(!conn.isClosed())
                System.out.println("Succeeded connecting to the Database!");
            
            Statement statement = conn.createStatement();  //用statement 來執行sql語句
            String sql = "select * from TrainData";   //這是sql語句中的查詢某個表,注意后面的emp是表名!!!
            ResultSet rs = statement.executeQuery(sql);  //用於返回結果
            
            String str = null;
            while(rs.next()){  //一直讀到最后一條表
                ArrayList<String> s= new ArrayList<String>();
                str = rs.getString("Sex");  //分別讀取相應欄位的信息加入到可變長數組中
                s.add(str);
                str = rs.getString("tatto");
                s.add(str);
                str = rs.getString("smoking");
                s.add(str);
                str = rs.getString("wearglasses");
                s.add(str);
                str = rs.getString("ridebike");
                s.add(str);
                str = rs.getString("isStudent");
                s.add(str);
                dataSet.add(s);  //加入dataSet
                //System.out.println(s);  輸出中間結果調試
            }
            rs.close();
            conn.close();
        }catch(ClassNotFoundException e){  //catch不同的錯誤信息,並報錯
            System.out.println("Sorry,can`t find the Driver!");
            e.printStackTrace();
        }catch(SQLException e){
            e.printStackTrace();
        }catch (Exception e) {
            e.printStackTrace();
        }finally{
            System.out.println("數據庫訓練數據讀取成功!");
        }
        return dataSet;
    }
    
    
    public ArrayList<String> read_testdata(String str) throws IOException  //將用戶輸入的一整行字符串分割解析成可變長數組
    {
        ArrayList<String> testdata=new ArrayList<String>();  //待返回
        StringTokenizer tokenizer = new StringTokenizer(str);  
        while (tokenizer.hasMoreTokens()) { 
            testdata.add(tokenizer.nextToken());
        }
        return testdata;
    }
}

Main.java

package IsStudent_bys;
 
import java.io.BufferedInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Scanner;
 
public class Main {
 
    //主函數,讀取數據庫,並讀入待判定數據,輸出結果
    public static void main(String[] args) {
        FetchData Fdata = new FetchData();   //java對函數的調用要先聲明相應的對象再調用
        Bayes bys = new Bayes();
        ArrayList<ArrayList<String>> dataSet = null; //訓練數據列表
        ArrayList<String> testSet = null; //測試數據
        try{
            System.out.println("從數據庫讀入訓練數據:");
            dataSet = Fdata.fetch_traindata();   //讀取訓練數據集合
            System.out.println("請輸入測試數據:"); 
            Scanner cin = new Scanner(new BufferedInputStream(System.in));  //從標准輸入輸出中讀取測試數據
            while(cin.hasNext())  //支持多條測試數據讀取
            {
                String str = cin.nextLine();   //先讀入一行
                testSet = Fdata.read_testdata(str);//將這一行進行字符串分隔解析后返回可變長數組類型
                //System.out.println(testSet);  //輸出中間結果
                String ans = bys.bys_Main(dataSet, testSet);  //調用貝葉斯分類器
                if(ans.equals("yes")) System.out.println("Yes!!! 根據已有數據推斷極有可能像是一個學生!");  //輸出結果
                else System.out.println("他/她 的特征不像一名學生!");
            }
            cin.close();
        }catch (IOException e) {  //處理異常
            e.printStackTrace();
        } 
    }
 
}

 

運行效果截圖:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM