决策树(Decision Tree)

dt

原理解读

  决策树(Decision Tree):是在已知各种情况发生概率的基础上,直观运用概率分析的一种图解法。决策树是一种树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。由于这种决策分支画成图形很像一棵树的枝干,故称决策树。

核心思想

树的构建

  • 步骤1:将所有的数据看成是一个节点(根节点),进入步骤2
  • 步骤2:根据划分准则,从所有属性中挑选一个对节点进行分割,进入步骤3
  • 步骤3:生成若干个子节点,对每一个子节点进行判断,如果满足停止分裂的条件,进入步骤4;否则,进入步骤2
  • 步骤4:设置该节点是叶子节点,其输出的结果为该节点数量占比最大的类别

划分准则

信息熵

信息熵:假设样本集合D中第k类样本所占的比例为$p_k(k=1,2,\ldots,y)$,则D的信息熵定义为:
$$Ent(D)=-\displaystyle \sum_{k=1}^y p_klog_2p_k$$
Ent(D)的值越小,则D的纯度越高

信息增益(ID3)

假设离散属性a有V个不同的取值$\lbrace a^1,a^2,\ldots,a^V \rbrace$,若使用a来对样本集D进行划分,则会产生V个分支节点,其中第v个分支节点包含了D中所有在属性a上取值为$a^v$的样本,记为$D^v$。我们可以计算出$D^v$的信息熵,再给分支结点赋予权重$\frac{\left| D^v \right|}{\left| D \right|}$
则可以计算出用属性a对样本集D进行划分所获得的”信息增益”。
$$Gain(D,a)=Ent(D)-\displaystyle \sum_{v=1}^V \frac{\left| D^v \right|}{\left| D \right|}Ent(D^v)$$
一般而言,信息增益越大,则意味着使用属性a来进行划分所获得的纯度提升越大,ID3决策树学习算法就是以信息增益为准则来划分属性

增益率(C4.5)

实际上,信息增益准则对可取值数目较多的属性有所偏好,为减少这种偏好可能带来的不利影响,著名的C4.5决策树算法不直接使用信息增益,而是使用”增益率”来选择最优划分属性。
增益率定义为
$$Gain_ratio=\frac{Gain(D,a)}{IV(a)}$$
$$IV(a)=-\displaystyle \sum_{v=1}^V \frac{\left| D^v \right|}{\left| D \right|} log_2 \frac{\left| D^v \right|}{\left| D \right|}$$
需要注意的是,增益率准则对可取值数目较少的属性有所偏好,因此C4.5算法先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的

基尼指数(CART)

基尼指数反应了从数据集D中随机抽取两个样本,其类别标记不一致的概率。因此基尼指数越小,则数据集D的纯度越高
$$Gini(D)=1-\displaystyle \sum_{k=1}^y {p_k}^2$$
属性a的基尼指数定义为
$$Gini_index(D,a)=\displaystyle \sum_{v=1}^V \frac{\left| D^v \right|}{\left| D \right|} Gini(D^v)$$
在选择属性集合时,选择使划分后基尼指数最小的属性作为最优划分属性

剪枝处理

剪枝(pruning)是决策树学习算法对付”过拟合”的主要手段
决策树剪枝的基本策略有”预剪枝(prepruning)”和”后剪枝(postpruning)”。

预剪枝(prepruning)

预剪枝:是在决策树生长过程中,对每个结点在划分前先进行估计,若当前结点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点标记为叶子结点
预剪枝优点:

  • 降低过拟合风险
  • 显著减少决策树的训练时间开销和测试时间开销
    预剪枝缺点:
  • 因为”贪心”本质,可能带来欠拟合的风险

后剪枝(postpruning)

后剪枝:先从训练集生成一颗完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶节点
后剪枝优点:

  • 泛化性能较好
  • 欠拟合风险较小
    后剪枝缺点:
  • 生成完全决策树后进行,并且自底向上对所有非叶结点进行逐一考察,时间开销大

算法流程

DT



代码实战

代码中所用数据为罗斯.昆兰(Ross Quinlan)当年所用的高尔夫模型
其实这里使用的算法并不是ID3,因为老师上课讲错了,误把ID3讲成最小熵,ID3应该是最小熵增益。但是作业要求是实现最小熵的ID3算法,所以我起名为ID3。但是标准的ID3,还需要一个比较的过程,确定哪一种分割是最小熵增益
DATA

ID3_main.m

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
%% %获取基本信息
clear;clc;close all;
%设置类别和标签最长为100字符
char_len=100;
%生成一个长度为100的空串
space(1:char_len)=' ';
%读取文本文档
fo=fopen('data1.txt','rt');
txt=textscan(fo,'%s');
fclose(fo);
%class_name为类别名称,如outlook,temperature等等
class_name=strsplit(txt{1}{1},',');
class_name=class_name(1:end-1);
%class_num为类别数
class_num=length(class_name);
%sample_num为样本数
sample_num=length(txt{1})-1;
data{sample_num,class_num+1}=[];
%读入数据
for i=1:sample_num
temp=strsplit(txt{1}{i+1},',');
for j=1:class_num
data{i,j}=[temp{j},'_',class_name{1,j}];
end
data{i,j+1}=temp{j+1};
end
%class_info存放每一个类别的标签信息,如第一个元胞中存放rain,sunny,overcast
class_info{1,class_num}=[];
for i=1:class_num
temp=unique(data(:,i));
class_info{i}=temp;
end

%% %生成树
No=1;
%创建100*100的字符矩阵存放树的信息。
tree(1:char_len,1:char_len)=' ';
[tree,No] = ID3_creat(data,class_name,tree,No);

%% %构建连接矩阵
%获取树的节点
tree_node=tree(1:2:No,:);
%获取树的标签
tree_label=tree(2:2:No,:);
%创建连接矩阵
vect{size(tree_node,1),size(tree_node,1)}=[];
%vect元胞记录父子关系
for i=1:size(tree_node,1)
tem=find(ismember(class_name,deblank(tree_node(i,:))));
if isempty(tem)
continue;
end
num=size(class_info{1,tem},1);
for j=1:num
temp=space;
temp(1:length(class_info{1,tem}{j}))=class_info{1,tem}{j};
for k=1:size(tree_label,1)
if isequal(tree_label(k,:),temp)
break;
end
end
vect{i,k+1}=temp;
end
end

%% %绘制树图
node=zeros(1,size(tree_node,1));
%根据vect元胞中的父子关系画出树图
for i=1:size(vect,2)
tem=vect(:,i);
for j=1:size(tem,1)
if ~isempty(tem{j})
node(i)=j;
break;
end
end
end
treeplot(node);

%% %写树的类别(节点)
[x,y]=treelayout(node);
x=x';
y=y';
text(x(:,1),y(:,1),tree_node);

%% %写树的标签(枝条)
x1=zeros(size(tree_label,1));
y1=zeros(size(tree_label,1));
%根据父子关系在父子节点中点写入标签
for i=2:length(node)
x1(i-1,1)=(x(i,1)+x(node(i),1))/2;
y1(i-1,1)=(y(i,1)+y(node(i),1))/2;
end
for i=1:size(tree_label)
temp=strsplit(tree_label(i,:),'_');
tree_label(i,:)=[temp{1},space(length(temp{1})+1:end)];
end
text(x1(:,1),y1(:,1),tree_label);

ID3_split.m

1
2
3
4
5
6
7
8
9
10
11
12
13
function bestfeature=ID3_split(data)
%求最小熵的分割算法
numfeatures = size(data,2) -1 ;
bestent = log2(numfeatures);
bestfeature = -1;
for i =1:numfeatures
ent = ID3_ent(data,i);
if ent < bestent
bestent = ent;
bestfeature = i;
end
end

ID3_ent.m

1
2
3
4
5
6
7
8
9
10
11
12
13
14
function ent=ID3_ent(data,i)
%求最小熵
info=tabulate(data(:,i));
ent=0;
for k=1:size(info,1)
loc=ismember(data(:,i),info{k,1});
info1=tabulate(data(loc,end));
temp=0;
for n=1:size(info1,1)
temp=temp-info1{n,3}/100*log2(info1{n,3}/100);
end
ent=ent+info{k,3}/100*temp;
end

ID3_creat.m

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
function [tree,No]=ID3_creat(data,class_name,tree,No)
classlist=data(:,end);
%如果标签全为yes或no则已经分完,返回
if size(tabulate(classlist),1)==1
tree(No,1:length(classlist{1}))=classlist{1};
return
end
%如果没有分完找到最好的特征,递归生成树
bestfeature_loc = ID3_split(data);
bestfeature=class_name{1,bestfeature_loc};
tree(No,1:length(bestfeature))=bestfeature;
featureValues=tabulate(data(:,bestfeature_loc));
for m=1:size(featureValues,1)
tree(No+1,1:length(featureValues{m,1}))=featureValues{m,1};
loc=ismember(data(:,bestfeature_loc),featureValues{m,1});
data1=data(loc,[1:bestfeature_loc-1,bestfeature_loc+1:end]);
[tree,No] = ID3_creat(data1,class_name(:,[1:bestfeature_loc-1,bestfeature_loc+1:end]),tree,No+2);
end



实验结果

DT

ID3, C4.5, CART性能比较

$$ \begin{array}{|c|c|c|c|c|} 算法 & 结构 & 特征选择 & 连续值 & 缺失值 \ \hline ID3&多叉树&信息增益&不支持&不支持\ C4.5&多叉树&信息增益比&支持&支持\ CART&二叉树&基尼系数&支持&支持\ \end{array}$$

决策树分类优缺点

  • 优点:
    • 数据量一般不会太大
    • 具有很强的可解释性
    • 生成的决策树简单直观
    • 可以处理多维度输出的分类问题。
    • 既可以处理离散值也可以处理连续值
    • 可以通过剪枝来权衡欠拟合和过拟合
    • 基本不需要预处理,不需要提前归一化,处理缺失值。
  • 缺点:
    • 树结构受样本影响较大
    • 复杂的模型很难用决策树解决
    • 寻找最优的决策树是一个NP难的问题
    • 如果某些特征的样本比例过大,生成决策树容易偏向于这些特征,这个可以通过调节样本权重来改善
-------------本文结束感谢您的阅读-------------
0%