R-数据挖掘-决策树ID3(四)
海林老师《数据挖掘》课程作业系列
要求:自己写R/Python代码、函数实现一系列算法
其他参见:
全文逻辑:(读者可将所有代码按顺序复制到RStudio,全选ctrl+A,运行ctrl+enter,查看结果)
- 分析
- 算法/函数
- 测试数据
- 测试代码
- 测试结果(截图)
分析:这个很难!!(难在递归生成树)
#实现 输入训练集 输出ID3方式得到的决策树(列表)
#输入:训练集数据框(要求最后一列为类别)
#输出:显示划分属性
#返回:嵌套列表,按划分属性的每个可取值各生成一个元素,
######每个元素的[[1]]为 c(划分属性名, 属性取值),
######每个元素的[[2]]为该取值下的子树;到列表最里层时,[[2]]表示此路径的最终类别
算法实现(编写函数):
(1)生成决策树:
# Build an ID3 decision tree from a training data frame.
#
# Args:
#   data: training data frame. Every column except the last is treated as a
#         categorical attribute; the last column is the class label.
#
# Side effects:
#   Prints the name of the attribute chosen at every split and the partitions
#   that the split produces.
#
# Returns:
#   A nested list with one element per observed value of the split attribute.
#   Each element is list(c(attribute_name, attribute_value), subtree), where
#   subtree is either another list of the same shape or the class label for
#   that path. When no split is needed (all rows share one class) the class
#   label itself is returned; a multi-column frame with zero rows yields NULL.
id3_jcs <- function(data) {
  stopifnot(is.data.frame(data))

  # Values that actually occur in x, in table() order. Unused factor levels
  # are skipped so that no empty partition is ever created.
  observed_values <- function(x) {
    counts <- table(x)
    names(counts)[counts > 0]
  }

  # Entropy (base 2) of the class column, i.e. the last column of `data`.
  info <- function(data) {
    counts <- as.numeric(table(data[[ncol(data)]]))
    # Empty classes are dropped: 0 * log(0) would otherwise produce NaN.
    rates <- counts[counts > 0] / nrow(data)
    -sum(rates * log(rates, 2))
  }

  # Weighted entropy of the class column after partitioning `data` on the
  # attribute in column `col_fen`.
  info_col <- function(data, col_fen) {
    values <- as.character(data[[col_fen]])
    total <- nrow(data)
    result <- 0
    for (value in observed_values(data[[col_fen]])) {
      part <- data[!is.na(values) & values == value, , drop = FALSE]
      result <- result + info(part) * (nrow(part) / total)
    }
    result
  }

  # Partition `data` on column index `col_fen`.
  # Returns a list of list(value, sub_frame); the split column is removed
  # from every sub_frame.
  split_data <- function(data, col_fen) {
    values <- as.character(data[[col_fen]])
    lapply(observed_values(data[[col_fen]]), function(value) {
      part <- droplevels(data[!is.na(values) & values == value, , drop = FALSE])
      # drop = FALSE keeps a data frame even when only the class column is
      # left; otherwise the recursion would receive a bare vector.
      part <- part[, -col_fen, drop = FALSE]
      print("------------------------------------")
      print(value)
      print("==========================================")
      print(part)
      list(value, part)
    })
  }

  # Column index of the attribute with the largest information gain.
  # Only attribute columns compete; the class column is never a candidate.
  # Ties go to the left-most attribute.
  best <- function(data) {
    n_attr <- ncol(data) - 1
    if (n_attr < 1) {
      return(1L)
    }
    parent_info <- info(data) # entropy of the unsplit table
    gains <- vapply(
      seq_len(n_attr),
      function(i) parent_info - info_col(data, i),
      numeric(1)
    )
    index <- which.max(gains)
    if (length(index) == 0) {
      index <- 1L
    }
    index
  }

  # Majority vote: the most frequent class label in the last column
  # (the first one in table() order wins a tie).
  majorityCnt <- function(data) {
    counts <- table(data[[ncol(data)]])
    names(counts)[which.max(counts)]
  }

  # Recursively grow the tree.
  createTree <- function(data) {
    # Stop: only the class column is left -> majority vote.
    if (length(data) == 1) {
      return(majorityCnt(data))
    }
    # Stop: every row carries the same class.
    labels <- data[[ncol(data)]]
    if (sum(table(labels) > 0) == 1) {
      return(labels[!is.na(labels)][1])
    }
    # Stop: nothing to learn from.
    if (nrow(data) == 0) {
      return(NULL)
    }
    bestFeature <- best(data)
    feature_name <- names(data)[bestFeature]
    # Report the chosen split attribute.
    print(paste0("********我是划分属性:", feature_name, "************"))
    # Split first (this prints all partitions), then recurse into each part.
    branches <- split_data(data, bestFeature)
    lapply(branches, function(branch) {
      list(c(feature_name, branch[[1]]), createTree(branch[[2]]))
    })
  }

  createTree(data)
}
(2)实现分类
# Classify one record with a tree of the shape built by id3_jcs.
#
# Args:
#   tree: either a nested list whose elements are
#         list(c(attribute_name, attribute_value), subtree_or_label),
#         or a bare class label (a tree that consists of a single leaf).
#   test: data frame holding the record to classify; only its first row is
#         read. It must contain a column for every attribute the walk visits.
#
# Returns:
#   The class label stored at the leaf that the record reaches, or NULL when
#   no branch matches the record's attribute value (e.g. a value that is NA
#   or does not appear in the tree).
classify <- function(tree, test) {
  xh <- function(node) {
    # Anything that is not a list is a leaf: the label itself.
    if (!is.list(node)) {
      return(node)
    }
    for (branch in node) {
      attr_name <- branch[[1]][1]
      attr_value <- branch[[1]][2]
      # isTRUE() treats an NA comparison as "no match" instead of failing.
      if (isTRUE(test[1, attr_name] == attr_value)) {
        return(xh(branch[[2]]))
      }
    }
    # No branch covers this attribute value.
    NULL
  }
  xh(tree)
}
数据测试:
测试数据:
书上的数据
训练数据选前13行
最后一行用于测试
# Textbook example data: 14 records, four categorical attributes and the
# class label in the last column.
age <- c(
  "youth", "youth", "youth", "youth", "youth",
  "middle_aged", "middle_aged", "middle_aged", "middle_aged",
  "senior", "senior", "senior", "senior", "senior"
)
income <- c(
  "high", "high", "medium", "low", "medium",
  "high", "low", "medium", "high",
  "medium", "low", "low", "medium", "medium"
)
student <- c(
  "no", "no", "no", "yes", "yes",
  "no", "yes", "no", "yes",
  "no", "yes", "yes", "yes", "no"
)
credit_rating <- c(
  "fair", "excellent", "fair", "fair", "excellent",
  "fair", "excellent", "excellent", "fair",
  "fair", "fair", "excellent", "fair", "excellent"
)
class <- c(
  "no", "no", "no", "yes", "yes",
  "yes", "yes", "yes", "yes",
  "yes", "yes", "no", "yes", "no"
)
# Spell out FALSE: the shorthand F is an ordinary variable that can be
# reassigned.
data <- data.frame(age, income, student, credit_rating, class,
                   stringsAsFactors = FALSE)
# First 13 rows train the tree; the last row is held out for testing.
train <- data[1:13, ]
test <- data[14, ]
# Train on the training set; the resulting tree is stored in xx.
xx <- id3_jcs(train)
# Classify the single held-out record.
test
mmm <- classify(xx, test)
mmm
# Classify several records, one row at a time.
test2 <- data[12:14, ]
test2
# seq_len() stays empty for a zero-row frame, unlike 1:nrow().
for (i in seq_len(nrow(test2))) {
  print(classify(xx, test2[i, ]))
}