決策樹算法以及matlab實現ID3算法


本文將詳細介紹ID3算法,其也是最經典的決策樹分類算法。

1、ID3算法簡介及基本原理 
ID3算法基於信息熵來選擇最佳的測試屬性,它選擇當前樣本集中具有最大信息增益值的屬性作為測試屬性;樣本集的划分則依據測試屬性的取值進行,測試屬性有多少個不同的取值就將樣本集划分為多少個子樣本集,同時決策樹上相應於該樣本集的節點長出新的葉子節點。ID3算法根據信息論的理論,采用划分后樣本集的不確定性作為衡量划分好壞的標准,用信息增益值度量不確定性:信息增益值越大,不確定性越小。因此,ID3算法在每個非葉節點選擇信息增益最大的屬性作為測試屬性,這樣可以得到當前情況下最純的划分,從而得到較小的決策樹。

設S是s個數據樣本的集合。假定類別屬性具有m個不同的值:這里寫圖片描述,設這里寫圖片描述是類這里寫圖片描述中的樣本數。對一個給定的樣本,它總的信息熵為這里寫圖片描述,其中,這里寫圖片描述是任意樣本屬於這里寫圖片描述的概率,一般可以用這里寫圖片描述估計。

設一個屬性A具有k個不同的值這里寫圖片描述,利用屬性A將集合S划分為k個子集這里寫圖片描述,其中這里寫圖片描述包含了集合S中屬性A取這里寫圖片描述值的樣本。若選擇屬性A為測試屬性,則這些子集就是從集合S的節點生長出來的新的葉節點。設這里寫圖片描述是子集這里寫圖片描述中類別為這里寫圖片描述的樣本數,則根據屬性A划分樣本的信息熵為這里寫圖片描述 
其中,這里寫圖片描述這里寫圖片描述是子集這里寫圖片描述中類別為這里寫圖片描述的樣本的概率。

最后,用屬性A划分樣本集S后所得的信息增益(Gain)為這里寫圖片描述

顯然這里寫圖片描述越小,Gain(A)的值就越大,說明選擇測試屬性A對於分類提供的信息越大,選擇A之后對分類的不確定程度越小。屬性A的k個不同的值對應的樣本集S的k個子集或分支,通過遞歸調用上述過程(不包括已經選擇的屬性),生成其他屬性作為節點的子節點和分支來生成整個決策樹。ID3決策樹算法作為一個典型的決策樹學習算法,其核心是在決策樹的各級節點上都用信息增益作為判斷標准來進行屬性的選擇,使得在每個非葉子節點上進行測試時,都能獲得最大的類別分類增益,使分類后的數據集的熵最小。這樣的處理方法使得樹的平均深度較小,從而有效地提高了分類效率。

2、ID3算法的具體流程 
ID3算法的具體流程如下: 
1)對當前樣本集合,計算所有屬性的信息增益; 
2)選擇信息增益最大的屬性作為測試屬性,把測試屬性取值相同的樣本划為同一個子樣本集; 
3)若子樣本集的類別屬性只含有單個屬性,則分支為葉子節點,判斷其屬性值並標上相應的符號,然后返回調用處;否則對子樣本集遞歸調用本算法。

數據如圖所示

序號  天氣  是否周末    是否有促銷   銷量
1   壞   是   是   高
2   壞   是   是   高
3   壞   是   是   高
4   壞   否   是   高
5   壞   是   是   高
6   壞   否   是   高
7   壞   是   否   高
8   好   是   是   高
9   好   是   否   高
10  好   是   是   高
11  好   是   是   高
12  好   是   是   高
13  好   是   是   高
14  壞   是   是   低
15  好   否   是   高
16  好   否   是   高
17  好   否   是   高
18  好   否   是   高
19  好   否   否   高
20  壞   否   否   低
21  壞   否   是   低
22  壞   否   是   低
23  壞   否   是   低
24  壞   否   否   低
25  壞   是   否   低
26  好   否   是   低
27  好   否   是   低
28  壞   否   否   低
29  壞   否   否   低
30  好   否   否   低
31  壞   是   否   低
32  好   否   是   低
33  好   否   否   低
34  好   否   否   低

采用ID3算法構建決策樹模型的具體步驟如下: 
1)根據公式這里寫圖片描述,計算總的信息熵,其中數據中總記錄數為34,而銷售數量為“高”的數據有18,“低”的有16 
這里寫圖片描述

2)根據公式這里寫圖片描述這里寫圖片描述,計算每個測試屬性的信息熵。

對於天氣屬性,其屬性值有“好”和“壞”兩種。其中天氣為“好”的條件下,銷售數量為“高”的記錄為11,銷售數量為“低”的記錄為6,可表示為(11,6);天氣為“壞”的條件下,銷售數量為“高”的記錄為7,銷售數量為“低”的記錄為10,可表示為(7,10)。則天氣屬性的信息熵計算過程如下: 
這里寫圖片描述 
這里寫圖片描述 
這里寫圖片描述

對於是否周末屬性,其屬性值有“是”和“否”兩種。其中是否周末屬性為“是”的條件下,銷售數量為“高”的記錄為11,銷售數量為“低”的記錄為3,可表示為(11,3);是否周末屬性為“否”的條件下,銷售數量為“高”的記錄為7,銷售數量為“低”的記錄為13,可表示為(7,13)。則節假日屬性的信息熵計算過程如下: 
這里寫圖片描述 
這里寫圖片描述 
這里寫圖片描述

對於是否有促銷屬性,其屬性值有“是”和“否”兩種。其中是否有促銷屬性為“是”的條件下,銷售數量為“高”的記錄為15,銷售數量為“低”的記錄為7,可表示為(15,7);其中是否有促銷屬性為“否”的條件下,銷售數量為“高”的記錄為3,銷售數量為“低”的記錄為9,可表示為(3,9)。則是否有促銷屬性的信息熵計算過程如下: 
這里寫圖片描述 
這里寫圖片描述 
這里寫圖片描述

根據公式這里寫圖片描述,計算天氣、是否周末和是否有促銷屬性的信息增益值。 
這里寫圖片描述 
這里寫圖片描述 
這里寫圖片描述

3)由計算結果可以知道是否周末屬性的信息增益值最大,它的兩個屬性值“是”和“否”作為該根節點的兩個分支。然后按照上面的步驟繼續對該根節點的兩個分支進行節點的划分,針對每一個分支節點繼續進行信息增益的計算,如此循環反復,直到沒有新的節點分支,最終構成一棵決策樹。

由於ID3決策樹算法采用了信息增益作為選擇測試屬性的標准,會偏向於選擇取值較多的即所謂的高度分支屬性,而這類屬性並不一定是最優的屬性。同時ID3決策樹算法只能處理離散屬性,對於連續型的屬性,在分類前需要對其進行離散化。為了解決傾向於選擇高度分支屬性的問題,人們采用信息增益率作為選擇測試屬性的標准,這樣便得到C4.5決策樹的算法。此外常用的決策樹算法還有CART算法、SLIQ算法、SPRINT算法和PUBLIC算法等等。

使用ID3算法建立決策樹的MATLAB代碼如下所示 
ID3_decision_tree.m

%% 使用ID3決策樹算法預測銷量高低
clear ;

%% 數據預處理
disp('正在進行數據預處理...');
[matrix,attributes_label,attributes] =  id3_preprocess();

%% 構造ID3決策樹,其中id3()為自定義函數
disp('數據預處理完成,正在進行構造樹...');
tree = id3(matrix,attributes_label,attributes);

%% 打印並畫決策樹
[nodeids,nodevalues] = print_tree(tree);
tree_plot(nodeids,nodevalues);

disp('ID3算法構建決策樹完成!');

id3_preprocess.m:

function [ matrix,attributes,activeAttributes ] = id3_preprocess(  )
%% ID3算法數據預處理,把字符串轉換為0,1編碼

% 輸出參數:
% matrix: 轉換后的0,1矩陣;
% attributes: 屬性和Label;
% activeAttributes : 屬性向量,全1;

%% 讀取數據
txt = {  '序號'    '天氣'    '是否周末'    '是否有促銷'    '銷量'
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  
        ''        ''      ''          ''            ''  }
attributes=txt(1,2:end);
activeAttributes = ones(1,length(attributes)-1);
data = txt(2:end,2:end);

%% 針對每列數據進行轉換
[rows,cols] = size(data);
matrix = zeros(rows,cols);
for j=1:cols
    matrix(:,j) = cellfun(@trans2onezero,data(:,j));
end

end

function flag = trans2onezero(data)
    if strcmp(data,'') ||strcmp(data,'')...
        ||strcmp(data,'')
        flag =0;
        return ;
    end
    flag =1;
end

id3.m:

function [ tree ] = id3( examples, attributes, activeAttributes )
%% ID3 算法 ,構建ID3決策樹
    ...參考:https://github.com/gwheaton/ID3-Decision-Tree

% 輸入參數:
% example: 輸入0、1矩陣;
% attributes: 屬性值,含有Label;
% activeAttributes: 活躍的屬性值;-1,1向量,1表示活躍;

% 輸出參數:
% tree:構建的決策樹;

%% 提供的數據為空,則報異常
if (isempty(examples));
    error('必須提供數據!');
end

% 常量
numberAttributes = length(activeAttributes);
numberExamples = length(examples(:,1));

% 創建樹節點
tree = struct('value', 'null', 'left', 'null', 'right', 'null');

% 如果最后一列全部為1,則返回“true”
lastColumnSum = sum(examples(:, numberAttributes + 1));

if (lastColumnSum == numberExamples);
    tree.value = 'true';
    return
end
% 如果最后一列全部為0,則返回“falseif (lastColumnSum == 0);
    tree.value = 'false';
    return
end

% 如果活躍的屬性為空,則返回label最多的屬性值
if (sum(activeAttributes) == 0);
    if (lastColumnSum >= numberExamples / 2);
        tree.value = 'true';
    else
        tree.value = 'false';
    end
    return
end

%% 計算當前屬性的熵
p1 = lastColumnSum / numberExamples;
if (p1 == 0);
    p1_eq = 0;
else
    p1_eq = -1*p1*log2(p1);
end
p0 = (numberExamples - lastColumnSum) / numberExamples;
if (p0 == 0);
    p0_eq = 0;
else
    p0_eq = -1*p0*log2(p0);
end
currentEntropy = p1_eq + p0_eq;

%% 尋找最大增益
gains = -1*ones(1,numberAttributes); % 初始化增益

for i=1:numberAttributes;
    if (activeAttributes(i)) % 該屬性仍處於活躍狀態,對其更新
        s0 = 0; s0_and_true = 0;
        s1 = 0; s1_and_true = 0;
        for j=1:numberExamples;
            if (examples(j,i)); 
                s1 = s1 + 1;
                if (examples(j, numberAttributes + 1)); 
                    s1_and_true = s1_and_true + 1;
                end
            else
                s0 = s0 + 1;
                if (examples(j, numberAttributes + 1)); 
                    s0_and_true = s0_and_true + 1;
                end
            end
        end

        % 熵 S(v=1)
        if (~s1);
            p1 = 0;
        else
            p1 = (s1_and_true / s1); 
        end
        if (p1 == 0);
            p1_eq = 0;
        else
            p1_eq = -1*(p1)*log2(p1);
        end
        if (~s1);
            p0 = 0;
        else
            p0 = ((s1 - s1_and_true) / s1);
        end
        if (p0 == 0);
            p0_eq = 0;
        else
            p0_eq = -1*(p0)*log2(p0);
        end
        entropy_s1 = p1_eq + p0_eq;

        % 熵 S(v=0)
        if (~s0);
            p1 = 0;
        else
            p1 = (s0_and_true / s0); 
        end
        if (p1 == 0);
            p1_eq = 0;
        else
            p1_eq = -1*(p1)*log2(p1);
        end
        if (~s0);
            p0 = 0;
        else
            p0 = ((s0 - s0_and_true) / s0);
        end
        if (p0 == 0);
            p0_eq = 0;
        else
            p0_eq = -1*(p0)*log2(p0);
        end
        entropy_s0 = p1_eq + p0_eq;

        gains(i) = currentEntropy - ((s1/numberExamples)*entropy_s1) - ((s0/numberExamples)*entropy_s0);
    end
end

% 選出最大增益
[~, bestAttribute] = max(gains);
% 設置相應值
tree.value = attributes{bestAttribute};
% 去活躍狀態
activeAttributes(bestAttribute) = 0;

% 根據bestAttribute把數據進行分組
examples_0= examples(examples(:,bestAttribute)==0,:);
examples_1= examples(examples(:,bestAttribute)==1,:);

% 當 value = false or 0, 左分支
if (isempty(examples_0));
    leaf = struct('value', 'null', 'left', 'null', 'right', 'null');
    if (lastColumnSum >= numberExamples / 2); % for matrix examples
        leaf.value = 'true';
    else
        leaf.value = 'false';
    end
    tree.left = leaf;
else
    % 遞歸
    tree.left = id3(examples_0, attributes, activeAttributes);
end
% 當 value = true or 1, 右分支
if (isempty(examples_1));
    leaf = struct('value', 'null', 'left', 'null', 'right', 'null');
    if (lastColumnSum >= numberExamples / 2); 
        leaf.value = 'true';
    else
        leaf.value = 'false';
    end
    tree.right = leaf;
else
    % 遞歸
    tree.right = id3(examples_1, attributes, activeAttributes);
end

% 返回
return
end

print_tree.m:

function [nodeids_,nodevalue_] = print_tree(tree)
%% 打印樹,返回樹的關系向量
global nodeid nodeids nodevalue;
nodeids(1)=0; % 根節點的值為0
nodeid=0;
nodevalue={};
if isempty(tree) 
    disp('空樹!');
    return ;
end

queue = queue_push([],tree);
while ~isempty(queue) % 隊列不為空
     [node,queue] = queue_pop(queue); % 出隊列

     visit(node,queue_curr_size(queue));
     if ~strcmp(node.left,'null') % 左子樹不為空
        queue = queue_push(queue,node.left); % 進隊
     end
     if ~strcmp(node.right,'null') % 左子樹不為空
        queue = queue_push(queue,node.right); % 進隊
     end
end

%% 返回 節點關系,用於treeplot畫圖
nodeids_=nodeids;
nodevalue_=nodevalue;
end

function visit(node,length_)
%% 訪問node 節點,並把其設置值為nodeid的節點
    global nodeid nodeids nodevalue;
    if isleaf(node)
        nodeid=nodeid+1;
        fprintf('葉子節點,node: %d\t,屬性值: %s\n', ...
        nodeid, node.value);
        nodevalue{1,nodeid}=node.value;
    else % 要么是葉子節點,要么不是
        %if isleaf(node.left) && ~isleaf(node.right) % 左邊為葉子節點,右邊不是
        nodeid=nodeid+1;
        nodeids(nodeid+length_+1)=nodeid;
        nodeids(nodeid+length_+2)=nodeid;

        fprintf('node: %d\t屬性值: %s\t,左子樹為節點:node%d,右子樹為節點:node%d\n', ...
        nodeid, node.value,nodeid+length_+1,nodeid+length_+2);
        nodevalue{1,nodeid}=node.value;
    end
end

function flag = isleaf(node)
%% 是否是葉子節點
    if strcmp(node.left,'null') && strcmp(node.right,'null') % 左右都為空
        flag =1;
    else
        flag=0;
    end
end

tree_plot.m

function tree_plot( p ,nodevalues)
%% 參考treeplot函數

[x,y,h]=treelayout(p);
f = find(p~=0);
pp = p(f);
X = [x(f); x(pp); NaN(size(f))];
Y = [y(f); y(pp); NaN(size(f))];

X = X(:);
Y = Y(:);

    n = length(p);
    if n < 500,
        hold on ; 
        plot (x, y, 'ro', X, Y, 'r-');
        nodesize = length(x);
        for i=1:nodesize
%            text(x(i)+0.01,y(i),['node' num2str(i)]); 
            text(x(i)+0.01,y(i),nodevalues{1,i}); 
        end
        hold off;
    else
        plot (X, Y, 'r-');
    end;

xlabel(['height = ' int2str(h)]);
axis([0 1 0 1]);

end

queue_push.m

function [ newqueue ] = queue_push( queue,item )
%% 進隊

% cols = size(queue);
% newqueue =structs(1,cols+1);
newqueue=[queue,item];

end

queue_pop.m

function [ item,newqueue ] = queue_pop( queue )
%% 訪問隊列

if isempty(queue)
    disp('隊列為空,不能訪問!');
    return;
end

item = queue(1); % 第一個元素彈出
newqueue=queue(2:end); % 往后移動一個元素位置

end

queue_curr_size.m

function [ length_ ] = queue_curr_size( queue )
%% 當前隊列長度

length_= length(queue);

end

轉載自https://blog.csdn.net/lfdanding/article/details/50753239


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM