前言:   

     决策树也是一种简单的分类模型,主要应用在反信用卡诈骗,骚扰邮件过滤。

这里主要结合:

    1: 乳腺癌预测例子

     2  破腹产预测例子

         https://archive.ics.uci.edu/ml/datasets/Caesarian+Section+Classification+Dataset

   3   学生购买电脑的例子讲述一下相关的原理

       生成的树如下:

     

 

   

   

 

数据集2,预测是否会购买电脑

   

    

 

目录:

  1.      决策树简介
  2.      决策树算法流程
  3.      决策树三种主要算法(ID3,C4.5, CART)
  4.      决策树剪枝
  5.      例子
  6.      问题

一  决策树简介

     垃圾邮件分类决策树模型:

    

    

   决策树主要由四部分组成:

      根节点(属性)

      子节点(属性)

       分枝(属性值)

      叶子节点(标签

二  决策树流程:

        2.1 CreateNode(创建节点)

              

  

     2.2  总体流程:

           

 

    2.3 熵

        反应离散随机事件出现的概率

       

       熵也是一个有约束的凸函数

 

     

     可以通过琴声不等式证明p 取1/n 时候值最大,且范围值为

     证明:

       由定理1:

       

       其中

       

        因为是上凸,所有

       

       

        

      

        可以看出下面三点

  1:     

   2:       时候为熵的极大值

   3:         分类种类越多,熵值越大。

 

  

      

         

   2.4 信息增益

         假设有K类,A代表具体属性,比如上面 年龄,D_i 是按照年龄中一个具体的划分,比如老年

取的数据集

       

          其中:

         

              

  2.5 信息增益比

          信息增益划分,会偏向选择较多特征的一列,为了纠正这一问题,可以选择信息增益比

           

              

 

    

三  决策树三种主要算法

     3.1 ID3

           输入:

                   训练数据集D , 和 特征集A,  阀值

           输出:

                   决策树

 

           步骤:

                   1: 若D中所有的实例都是同一类,则T为单节点树,作为叶节点返回

                   2: 若特征集为空集, 则T为单节点树,将D中实例数最多的类作为叶节点返回

                   3: 计算特征集A中信息增益,选择信息增益最大的特征Ag

                    3: 计算特征集A中信息增益比,选择信息增益比最大的特征Ag

                   4:   如果该信息增益小于阀值e,  则T为单节点树,返回

                    5   对A中每个属性a,将D 按照a 划分为若干个非空集合Di, 构建新的节点树,该节点和其子节点构成

                 树T,返回T     例如:{A2:{a21:value, a22:{...},...}

                     6: 对第i个子节点,以Di为训练集,以A-{Ag}为特征集,递归上面步骤1--5

                 

              

          

     3.2 C4.5

             

                     唯一的区别选择了信息增益比

 

     3.3 CART

            最小二乘回归树生成算法

            可以对连续数据处理

             输入:

                      训练集D

             输出:

                     回归树f(x),也是二分树

           

             1: 选择最优切分变量j 和 切分点s, 求解

                      

                      遍历遍历j,对固定的切分遍历j扫描切分点s, 选择使得上面达到最小的(j,s)

                    

 

              2: 对选定的对(j,s)划分区域 ,并决定对应的输出值

                

             3:  继续对两个子区域调用步骤1,2 直到满足停止条件

              4: 将输入空间划分为M个区域 生成决策树

                     

 

    停止条件:

           当前的每个叶节点方差和小于一定值,例如上面乳腺癌按照年龄分,只分一次

         

      

"""
根据最小标准差方差获得最优的切分点
原理上就是数据尽量是同类
Args
   j: 属性,某一列
   dataList: 数据集

retrun
    j, s 最优切分点,这里考虑的只是连续数据
"""
def  MiniTree(j, dataList):
    
    #print("j: ",j, "\n data",dataList)
    sList =[int(item[j]) for item in dataList]
    sList.sort()
    setS = set(sList)
    print("\n ******setS******** ",setS )
    miniVar = float("inf")   
    miniS = 0
    leftR  =  []
    rightR =  []
    for s in setS:
        
        leftR.clear()
        rightR.clear()
        for item in dataList:
            a = int(item[j])
            y = int(item[-1])
            if a <s:    
                leftR.append(y)
            else:
                rightR.append(y)
         
        var1 = 0
        var2 = 0
        C1 = None
        C2 = None
        if len(leftR)>0:
            var1 = np.var(leftR)*len(leftR)
            C1 = np.mean(leftR)
                
        if len(rightR)>0:
            var2 = np.var(rightR)*len(rightR)
            C2 = np.mean(rightR)
        var = var1+var2
        #print("s: ",s, "\t currQ ",var)
        if var <miniVar:       
                miniVar = var
                miniS = s
                
                print("\n  miniS: \t  ",s,"\t C1: ",C1, "\t C2: ",C2)

 

   
                

      基尼指数      : CART中用来选择最优特征,反应样本集合的不确定性

       

        

 

        如何证明Gini 极大值是

       证明:

                根据约束条件,加上一个仿射函数,配成

              

        对        求偏导数

               

                 

     重新带入上式

            

          对 求偏导数, 则

             

            

         则:

              为极值点

 

     CART生成算法:

       输入:

           训练数据集D,停止计算条件

      输出:

          CART决策树

 

        s1: 设结点的训练数据集为D,计算现有特征对该数据集的基尼系数,对每一个特征A,对每一个取值a

             根据样本点取值对A=a ,将数据集分割成D1,D2两部分

        s2:  在所有可能的特征A 以及它们的切分点a,选择基尼指数最小的特征A及对应的切分点a作为最优切分点

      生成两个子结点,将数据集划分到对应的两个子结点中去,

        s3: 对两个子结点,递归调用s1,s2,直到满足停止条件

         s4: 生成CART树

             

              

四  剪枝   

  

    主要防止过拟合,增加泛化能力

 

    分为预剪枝,和后剪枝

    常见的后剪枝算法为: REP, PEP, CCP等算法


  4.1  决策树剪枝:  

   输入:

          算法生成整个树T, 参数

    输出:

           修剪后的树 Tn

   

    4.1 计算每个结点的经验熵

    4.2  递归从树的叶结点回缩(结点)

    4.3  设一组叶结点回缩前和回缩后整树分别为,对应的损失函数分别为

        如果,则进行剪枝,父结点变成新的叶结点

     4.4  返回4.2 直到不能继续

 

    损失函数:T为叶节点个数

      

    经验熵:

       

      则上面也可以写为:

       

         

     算法里面:

        需要先找到叶节点,然后递归回父节点,遍历出同一个父节点下面的其它叶节点

      才能得到剪枝后T数目

 

    

          剪枝后

          


  4.2   CART 剪枝

    

          子树序列形成:

                 

             对于固定的 ,存在不同的剪枝方案,但存在一个最优的子树

          

          以 t为单结点的子树t损失函数:

        

         以 t 为根结点损失函数为:

          

          可以调整参数a,使得二者相等

          

 

                此刻:

           表示剪枝后,整体损失函数减少程度

 

              在剪枝得到的子树序列T0,T1.....Tn中交叉验证选取最优子树

 

              利用独立数据集,测试子树序列的基尼系数,或者平方误差,这里面的每一轮生成一个最优的Tree

下一轮在前一轮基础上,重新剪枝

 

         算法流程:


          输入:

                 CART算法生成的决策树T0

          输出:

                  最优决策树

 

          1 : 设k = 0, T= T0

          2:  

          3     自下而上对内部各个结点t计算以及

               

 

             这里Tt表示以t为根结点的子树,C(Tt)是对训练集的预测误差,|T_t|是Tt的叶结点个数

 

             4:

的内部结点的子树,进行剪枝,以多数表决法得到树

 

   5: 设k=k+1,,

 

   6: Tk 不是由根结点以及两个叶结点构成的树,否则Tk= Tn

   7: 采用交叉验证得到子树T0,T1,...Tn选择最优子树

 这里注意:

                        剪枝后,一定要保证各个叶结点的精度。

 

 

               

        代码实现:

     为了方便,生成树后先给不同叶结点打上唯一的Tag

    

          

 """
    给叶结点打上唯一的标签
    
    """
    def AddTag(self, tree):
        
        keys =['left', 'right']
       
        for key in keys:
            subTree = tree[key]   
            if self.IsTree(subTree):
                self.AddTag(subTree)
            else:
                Tag = "TAG_"+key
                if Tag not in tree:
                    x = self.LeafTag.pop(0)
                    tree[Tag] = x
                

        

    这里面给出一种单轮选择最优Gini系数的方案,不同轮即上面k=k+1,最好使用不同的数据集

   同时也要注意,剪枝后,要使用数据集对新生成的Tree验证

   

   """
    剪枝
    Args
        dataList: 数据集
        tree:  树
    return
        None
    """
    def Prunch(self, dataList, dataLabel,tree):
        
        ct = 0.0
        c_Beaf = 0.0
        
        
        
        if  not self.IsTree(tree) or len(dataLabel)<1:
             return
            
        BeafNum = self.GetLeafs(tree) ##叶结点数目|T|
        ct = self.GetGini(dataLabel)
        if BeafNum<2: ###本身已经是叶结点了
            return 
        
        dict_beaf={}
        i = 0
        for  data in dataList: 
            perdict_Label, Tag_ID = self.Perdict(tree,data)
            
            if Tag_ID not in dict_beaf:
                dict_beaf[Tag_ID] =[]
            test_label = dataLabel[i]
            beaf_List = dict_beaf[Tag_ID]
            beaf_List.append(test_label)
            i = i+1
        #keyList = dict_beaf.keys()
       # print("\n ****keyList***\n\t ",keyList, "\t  len: ",len(keyList))
       
       
        c_Beaf = self.GetBeafGini(dict_beaf)
        alpha = (ct-c_Beaf)/(BeafNum-1)
        
        
        
        
       
        name = tree['name']
        feat = self.train_feature.index(name)
        val = tree['feature']
       # print("\n ct:  \t",ct, "\t c_Beaf: \t",c_Beaf, "\t alpha: ",alpha)
        print("\n name: ",name , "\t  val : ",val)
        print("\n t为根结点Gini :\t %.3f"%ct, "\t t为树的Gini:\t %.3f"%c_Beaf, "\t alpha: %.3f"%alpha)
        L_Data, L_Label, R_Data, R_Label =self.SplitData(dataList, feat, val, dataLabel)
        LTree= tree['left']
        RTree =tree['right']
        self.Prunch(L_Data,L_Label ,LTree)
        self.Prunch(R_Data, R_Label,RTree )
                
        
        
       
        return None
    

 name:  worst perimeter           val :  106.2

 t为根结点Gini :     0.371   t为树的Gini:       0.089   alpha: 0.017

 name:  worst smoothness          val :  0.1777

 t为根结点Gini :     0.042   t为树的Gini:       0.032   alpha: 0.001

 name:  worst concave points      val :  0.1607

 t为根结点Gini :     0.022   t为树的Gini:       0.019   alpha: 0.000

 name:  mean area         val :  698.8

 t为根结点Gini :     0.022   t为树的Gini:       0.019   alpha: 0.000

 name:  perimeter error           val :  4.138

 t为根结点Gini :     0.022   t为树的Gini:       0.019   alpha: 0.000

 name:  worst texture     val :  30.25

 t为根结点Gini :     0.022   t为树的Gini:       0.019   alpha: 0.001

 name:  mean fractal dimension    val :  0.05628

 t为根结点Gini :     0.124   t为树的Gini:       0.114   alpha: 0.003

 name:  smoothness error          val :  0.007499

 t为根结点Gini :     0.077   t为树的Gini:       0.073   alpha: 0.002

 name:  mean radius       val :  12.77

 t为根结点Gini :     0.000   t为树的Gini:       0.000   alpha: 0.000

 name:  mean radius       val :  12.77

 t为根结点Gini :     0.000   t为树的Gini:       0.000   alpha: 0.000

 name:  mean texture      val :  15.56

 t为根结点Gini :     0.349   t为树的Gini:       0.222   alpha: 0.021

 name:  mean perimeter    val :  102.5

 t为根结点Gini :     0.346   t为树的Gini:       0.267   alpha: 0.079

 name:  worst smoothness          val :  0.1021

 t为根结点Gini :     0.262   t为树的Gini:       0.217   alpha: 0.011

 name:  mean radius       val :  18.08

 t为根结点Gini :     0.000   t为树的Gini:       0.000   alpha: 0.000

 name:  worst concavity           val :  0.1932

 t为根结点Gini :     0.265   t为树的Gini:       0.220   alpha: 0.022

 name:  mean radius       val :  16.02

 t为根结点Gini :     0.375   t为树的Gini:       0.333   alpha: 0.042

五  例子

     1: 乳腺癌例子

         

# -*- coding: utf-8 -*-
"""
Created on Mon Sep  9 16:49:01 2019

@author: chengxf2
"""

import numpy as np
import matplotlib.pyplot  as plt
import sys,os
from sklearn.datasets import load_breast_cancer

class CART:
    def IsTree(self, obj):
        
        bTree =(type(obj).__name__=='dict')
    
        return bTree
    
    """
    获得叶结点个数
    Args
        tree:    树
    return
        num
    """
    def  GetLeafs(self, tree):
        
        numLeaf = 0
        keys =['left', 'right']
        
        for key in keys:
            subTree = tree[key]
            
            if self.IsTree(subTree):
                numLeaf += self.GetLeafs(subTree)
            else:
                numLeaf +=1
       
        
        return numLeaf
        
        
    
    """
    获得树的深度
    Args
        tree:    树
    
    return 
        depth
    """
    def GetDepth(self, tree):
        
        print("====depth##########")
        
        maxDepth = 0
        keys =['left', 'right']
        
        for key in keys:
            subTree = tree[key]
            
            if self.IsTree(subTree):
                depth  =1+ self.GetDepth(subTree)
            else:
                depth = 1
            
            if depth >maxDepth:
                maxDepth = depth
       
        
        return maxDepth
        
        
            
        
        
        
        
    """
     加载数据集
     Args:
         None
    return 
      feature_names: 
     ['mean radius' 'mean texture' 'mean perimeter' 'mean area'
     'mean smoothness' 'mean compactness' 'mean concavity'
     'mean concave points' 'mean symmetry' 'mean fractal dimension'
     'radius error' 'texture error' 'perimeter error' 'area error'
     'smoothness error' 'compactness error' 'concavity error'
     'concave points error' 'symmetry error' 'fractal dimension error'
     'worst radius' 'worst texture' 'worst perimeter' 'worst area'
     'worst smoothness' 'worst compactness' 'worst concavity'
     'worst concave points' 'worst symmetry' 'worst fractal dimension']
    
        
     target_Name: 
       ['malignant' 恶性  'benign'  良性]
      ### 数据总长 569 ,300用来Train ,269 用来测试,剪枝
    """
    def LoadData(self):
    
        cancer = load_breast_cancer()
        data = cancer['data']
        target = cancer['target']
        target_Name = cancer['target_names']
        DESCR = cancer['DESCR']
        feature_names = cancer['feature_names']
        
    
        #print("\n data:  \n  ", data)
        #print("\n target:  \n  ", target)
        #print("\n target_Name:  \n  ", target_Name)
        #print("\n DESCR:  \n  ", DESCR)
        #print("\n feature_names:  \n  ", feature_names)
        
        self.trainData = data[0:300]   
        self.train_target = target[0:300]
        self.target_name= target_Name
        self.train_feature = feature_names.tolist()
     
        self.testData = data[300:-1]
        self.test_target= target[300:-1]
       
        

        
       
        

    
    """
    分割数据
    Args
        dataMat: 输入矩阵
        col: 列
        val: 值
    Returns
        L_Data: 小于该值的矩阵
        L_Labels: 左边标签
        R_Data: 右边矩阵
        R_Labels: 右边标签
    """
    def SplitData(self, dataList, col, val, Labels):
        L_Data =[]; L_Label=[]
        R_Data =[]; R_Label=[]
        
        m,n = np.shape(dataList)
       # print("\n m: ",m, "n: ",n)
        
        for i in range(m):
            cur = dataList[i][col]
            data = dataList[i]
            #print("\n data ",data)
            label = Labels[i]
            if cur<val:
                L_Data.append(data)
                L_Label.append(label)
            else:
                R_Data.append(data)
                R_Label.append(label)
        
       # print("\n =R_Mat= \n ",np.shape(R_Data))
        #print("======================\n")
        #print("\n  L_Data: \n ",np.shape(L_Data))
        return L_Data, L_Label, R_Data, R_Label
    
    
    """
    分割数据
    Args
        dataMat: 输入矩阵
        col: 列
        val: 值
    Returns
        L_Data: 小于该值的矩阵
        L_Labels: 左边标签
        R_Data: 右边矩阵
        R_Labels: 右边标签
    """
    def GetSubLabel(self, dataList, col, val, Labels):
        L_Label=[]
        R_Label=[]
        
        m,n = np.shape(dataList)
     
        
        for i in range(m):
            cur = dataList[i][col]

            label = Labels[i]
            if cur<val:

                L_Label.append(label)
            else:

                R_Label.append(label)

        return L_Label,R_Label
    
    
    def GetGini(self, Labels):
        
        m = len(Labels)   
        dictItem ={}
        gini = 0.0
        
        
        if m <1:
            return None
        
        
        for label in Labels:
            if label not in dictItem:
                dictItem[label]=0
            dictItem[label]=dictItem[label]+1
            
        for key in dictItem.keys():
            prob = np.power(dictItem[key]/m,2)
            gini +=prob
            
        
        return 1-gini
        
        
        
    
    """
    选择最优特征
    Args
        dataList:  数据集
        Labels:  标签集
    return 
        L_Data
        R_Data
        L_Label
        R_Label
    """
    def ChooseBestFeatures(self, dataList, Labels):
        
        m,n = np.shape(dataList)
        
        miniGini = float("inf")  ##Gini 系数选择最小的
        bestCol= 0  ##最佳特征
        bestFeature = 0  ##划分点
        
        for i in range(n):
            
            item = [data[i] for data in dataList]
            setFeature = set(item)
            
            for feature in setFeature:
                
               L_Label, R_Label= self.GetSubLabel(dataList, i, feature, Labels)
               
               m1 = len(L_Label)
               m2 = len(R_Label)
               
               if m1==0 or m2 ==0: continue
               
               gini = (m1*self.GetGini(L_Label)+m2*self.GetGini(R_Label))/m
               
               if gini<miniGini:
                   miniGini = gini
                   bestCol = i
                   bestFeature = feature
        print("\n minGini ",miniGini, "bestCol ",self.train_feature[bestCol], "\t bestFeature ",bestFeature)
        return  bestCol, bestFeature
                   
            
        
        
    """
    创建树
    Args
        dataList:    数据集
        Labels:      分类结果
    
    return 
       Tree
    """      
    def CreateTree(self, dataList, Labels):
        
      

        setLabel = set(Labels)
        m = len(dataList) ##样本个数
       
       
        if  1 == len(setLabel): ###只有一个分类
                return Labels[0]
        
        elif  1 ==m:  ##只有一个样本
                return Labels[0]
        
        feat, val = self.ChooseBestFeatures(dataList, Labels)
        
        tree ={}
        
        tree['feature']=val
        tree['name'] = self.train_feature[feat]
        L_Data, L_Label, R_Data, R_Label =self.SplitData(dataList, feat, val, Labels)
        tree['left']=self.CreateTree(L_Data,L_Label )
        tree['right']=self.CreateTree(R_Data, R_Label )
        
        return tree
    
    
    
    """
    分类
    Args
      myTree 决策树
      Labels 标签
      testVec 当前数据集
    
    return
      测试出来的标签
    """
    def Perdict(self, myTree, data):
        
        keyList= ['left','right']
        name = myTree['name']
        feature = myTree['feature']
        col = self.train_feature.index(name)
       
       

        cur = data[col]
        if cur<feature:
                subTree = myTree['left']
        else:
                subTree = myTree['right']
            
        if self.IsTree(subTree):
                label = self.Perdict(subTree, data)
        else:
                label =subTree


        return label
    
    """
    训练
    """
    def Train(self):
        

            tree = self.CreateTree(self.trainData, self.train_target)
            depth = self.GetDepth(tree)
            num = self.GetLeafs(tree)
            print("\n tree: \n ",tree)
            print("\n 叶结点数:    \t",num)
            print("\n  树深度:    \t ",depth )
            print("\n ***************树已生成*******************\n")
            num = len(self.testData)
            err = 0
            for i  in range(num):
               data = self.testData[i]
               true_label = self.test_target[i]
               label= self.Perdict(tree, data)
               print("\n label ",label, "\t  target ",true_label)
               
               if true_label != label:
                   err=err+1
            print("\n 测试样本: \t", num, "测试错误率: \t  ",err/num, "\t err: ",err)
    
       
    
    
    """
    初始化
    Args
        None
    return
        None
    """
    def __init__(self):
        

        self.LoadData()
        self.Train()
        
    

cart =CART()

    

         2: 剖腹产预测例子

       

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 28 13:59:01 2019

@author: chengxf2
"""
"""
决策树
"""
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib import cm
import os,sys
import operator
import math
from decimal import *
import treePlot
from imp import reload 
import copy
import miniTree as mini
#import csvFun

"""
        数据集2:   根据学生情况,预测是否购买电脑
        self.dataList=[['teenager' ,'high', 'no' ,'same', 'no'],
               ['teenager', 'high', 'no', 'good', 'no'],
               ['middle_aged' ,'high', 'no', 'same', 'yes'],
               ['old_aged', 'middle', 'no' ,'same', 'yes'],
               ['old_aged', 'low', 'yes', 'same' ,'yes'],
               ['old_aged', 'low', 'yes', 'good', 'no'],
               ['middle_aged', 'low' ,'yes' ,'good', 'yes'],
               ['teenager' ,'middle' ,'no', 'same', 'no'],
               ['teenager', 'low' ,'yes' ,'same', 'yes'],
               ['old_aged' ,'middle', 'yes', 'same', 'yes'],
               ['teenager' ,'middle', 'yes', 'good', 'yes'],
               ['middle_aged' ,'middle', 'no', 'good', 'yes'],
               ['middle_aged', 'high', 'yes', 'same', 'yes'],
               ['old_aged', 'middle', 'no' ,'good' ,'no']]
        self.Labels=['age','input','student','level']
        RID	age	 income	 student	credit_rating 	class:buys_computer
"""
#
from csvFun import LoadFile

class Shannon:
    
    '''
    学生购买电脑例子
    Args
       None
    return 
       data
    '''
    def LoadDataStudent(self):
        self.dataList=[['teenager' ,'high', 'no' ,'same', 'no'],
                       ['teenager', 'high', 'no', 'good', 'no'],
                       ['middle_aged' ,'high', 'no', 'same', 'yes'],
                       ['old_aged', 'middle', 'no' ,'same', 'yes'],
                       ['old_aged', 'low', 'yes', 'same' ,'yes'],
                       ['old_aged', 'low', 'yes', 'good', 'no'],
                       ['middle_aged', 'low' ,'yes' ,'good', 'yes'],
                       ['teenager' ,'middle' ,'no', 'same', 'no'],
                       ['teenager', 'low' ,'yes' ,'same', 'yes'],
                       ['old_aged' ,'middle', 'yes', 'same', 'yes'],
                       ['teenager' ,'middle', 'yes', 'good', 'yes'],
                       ['middle_aged' ,'middle', 'no', 'good', 'yes'],
                       ['middle_aged', 'high', 'yes', 'same', 'yes'],
                       ['old_aged', 'middle', 'no' ,'good' ,'no']]
         
        self.dataLabel =['age','input','student','level','result']
        
    
    """
    降维
    Args
       age:连续数据
           
    return 
       离散数据
    """
    def ReduceDimension(self, age):
        
        #print("age ",min(age), max(age))
        
        DiAge =[]
        for x in age:
            
            if x<32:
                DiAge.append("Y")
          
          
            else:
                DiAge.append("M")
        
        return DiAge
    
    """
    转换为高斯分布
    
    Arg:
        age: list
        u: 平均值
    return
        复合高斯分布的数据
    """
    def GetInfo(self, dataList):
        
        
        age0 = []
        age1 = []
        
        for item in dataList:
            print("-1 ",item)
            if item[-1]=='0':
               age0.append(int(item[0]))
            else:
               age1.append(int(item[0]))
        
        
        #求方差
        mean0 = np.mean(age0)
        var0 = np.var(age0)
        mean1 = np.mean(age1)
        var1 = np.var(age1)
        #求标准差
        std0 = np.std(age0,ddof=1)
        std1 = np.std(age1,ddof=1)
        #print("mean0 ", mean0, " var0: ", var0, "\t mean1 ",mean1, "\t var1: ",var1)
        #print("\n std0: ",std0, "\t std1: ",std1)
        

    """
    如果属性集只有一个属性了,返回概率最高的那个
    Args
       数据
    return 
        分类结果
    """
    def  MajorCnt(self, classList):
        
        dictCount ={}
        
        for item  in classList:
           # print("\n vote ",vote)
            vote = item[-1]
            if vote not in dictCount:
                dictCount[vote] = 0
            dictCount[vote]+=1
           
       # print("\n classCount ",dictCount)
        
        sortKey = sorted(dictCount.items(), key =lambda d:d[1], reverse= True)
            
        return sortKey[0][0],dictCount
    
    """
    获得数据集路径,以及保存Tree路径
    arg:
        None
     return:
         File Path
    """
    def GetPath(self):
        
        ##
        fileName = "caesarian.csv"
        path  = os.path.abspath(fileName)
        self.path = path
        self.treePath = os.path.abspath("tree")
        
    """
    加载数据集
    Args:
        None
    return 
        dataList
    """
    def LoadData(self):
        
 
        #labels = ['no surfacing','flippers']  ###相当于属性
        #Caesarian 剖腹产
        dataList = LoadFile(self.path)
        self.dataList = dataList[1:]
        print("\n ****************************\n")
        #mini.MiniTree(0, self.dataList)
       
      
        age = [int(item[0]) for item in self.dataList]
        age = self.ReduceDimension(age)
        
        for i in range(len(self.dataList)):
            self.dataList[i][0]=age[i]

        self.dataLabel = ['Age','Num','time','Blood','Heart','Result']
      
     
      
      
    """
    计算单个熵
    Args
       prob
    return 
       H: 熵
    """
    def GetEnt(self, prob):
        
        ent = 0.0
        if prob==0 or prob ==1:
            
            return 0
        else:
            
            ent = -prob*np.log2(prob)
        return ent
    
    """
    获得对应的熵
    Args:
        dataList
    return 
        ent:熵
            
    """
    def GetHD(self, dataList):
        
        num = len(dataList)
       # print("样本个数: ",num)
        labelDict ={}
        
        for i in range(num):
            label = dataList[i][-1]
            
            if label not in labelDict:
                labelDict[label]=0
            labelDict[label]= labelDict[label]+1
        
        
        ent = 0.0
        
        for key in labelDict:
            
            prob = labelDict[key]/num
            ent += self.GetEnt(prob)
        #print("\n  ent::::  ",ent)
        return ent
        
            
        
    """
    调试
    """
    def Debug(self):
        
        n = np.arange(2,30)
        shanList =[]
        
        for i in range(2,30):
            
            p = 1/i
            
            y = self.GetShan(p)
            shanList.append(y*i)
        
        plt.plot(n, shanList)
        plt.show()
        
    """
    根据指定的属性,获取对应的属性
    Args
       col: 对应的一列
       attr: 属性
       
    return 
      匹配的数据集
    """
    def SpliteData(self,dataList, col, attr):
        

        subData = []
        
        for item in dataList:
            
            curAttr = item[col]
            #print("\n curAttr :",curAttr)
            
            data= []
            if curAttr == attr:
                data1 = item[:col]
                data2 = item[col+1:]
                data.extend(data1)
                data.extend(data2)
                #print("\n data1: ",data1,"\t data2: ",data2, "curAttr: ",curAttr)
                subData.append(data)
                
        n = len(subData)
        #print("n: ",n ,"\n subData ",subData)
        return subData
    
    """
    使用信息增益
    args
       dataList
    return 
       对应列

    """
    def ChooseAttr(self, dataList):
        
        numAttr = len(dataList[0])-1 ##最后一个是属性
        num = float(len(dataList))
        baseHD = self.GetHD(dataList)
        
        bestGain = 0.0
        bestAttr = 0 
        
        for col in range(numAttr):
            
            attrList = [data[col] for data in dataList]
            setAttr = set(attrList)
            
            HDa = 0.0
            
            HA =0.0
            for attr in setAttr:
                
                subData = self.SpliteData(dataList, col, attr)
                
                prob = len(subData)/num
                ent = self.GetHD(subData)
                HA += self.GetEnt(prob)
                HDa +=prob*ent
            gDA = (baseHD-HDa)/(HA+1.0)  ##除以HA 就是隐形增益比
            #print("HA ",HA)
            #print("infoGain ",infoGain)
            if (gDA> bestGain):
                bestGain = gDA
                bestAttr = col
        #print("\n ******bestAttr*********", bestAttr)
     
        return bestAttr,bestGain
                


   
    """
    创建树
    Args
       DataSet, labels
    
    return 
       treeDict
    """     
    def CreateTree(self, dataList, labels):
        
        kind = [item[-1]  for item in dataList] ##最后一列Result
        
     
      
               
        if kind.count(kind[0])== len(kind):  
            return kind[0]
        
        
        ###只剩下一个属性了
        if len(dataList[0]) ==1:  ###只有一个属性了
            label,dictCount = self.MajorCnt(dataList)
            return label
        
        bestAttr,bestGain = self.ChooseAttr(dataList) ##最佳一列
        if bestGain<self.epsilon :
             label,dictCount = self.MajorCnt(dataList)
             if label =='Y':
                print("***********error**********",dataList, "\t dict ",dictCount)
             return label
        #print("\n ",bestGain)
        bestLabel = labels[bestAttr]
        
        del labels[bestAttr] ##删除某一列
        myTree ={bestLabel:{}}
        
        ###创建分枝
        branch = [item[bestAttr] for item in dataList]
        setBranch = set(branch)
        
        for key in setBranch:
            
            subLabel = labels[:]
            subData = self.SpliteData(dataList, bestAttr, key)    
            subTree = self.CreateTree(subData, subLabel)
            myTree[bestLabel][key] = subTree
        
        return myTree
        
        
        
        
    """
    训练
    Args
       None
    
    return
      树
    """
    def Train(self):
    
          label = copy.deepcopy(self.dataLabel)
         
          
          tree = self.CreateTree(self.dataList, label) 
         
          #print("\n Train: dataLabel: ",self.dataLabel)
          
          self.storeTree(tree)
          self.grabTree()
         
          #print("\n tree ",tree)
          
          
    """
    获得叶节点数目
    
    Args:
         Tree
    return
         numLeaf
    """
    def GetNumLeaf(self, myTree):
        
        numLeaf = 0
        
        firstNode = list(myTree.keys())[0]
       
        #print("\n type::: ",type(firstNode))
        secondDict = myTree[firstNode]
       
        keys = list(secondDict.keys())
        #print("\n keys::: ",keys)
        
        for key in keys:
            if type(secondDict[key]).__name__=='dict':
                numLeaf += self.GetNumLeaf(secondDict[key])
            else:
                numLeaf+=1
        return numLeaf
            
            
    
    """
    获取树的深度
    Args
      Tree
    return
       Depth
    """
    def GetTreeDepth(self, myTree):
        
        maxDepth = 0
        firstNode = list(myTree.keys())[0]
        secondDict = myTree[firstNode]
        
        keys = list(secondDict.keys())
        
        for key in keys:
            
            if type(secondDict[key]).__name__=='dict':
                depth =1+ self.GetTreeDepth(secondDict[key])
            else:
                depth = 1
            
            if depth>maxDepth:
                maxDepth = depth
        return maxDepth
    
    
    """
    保存模型
    Args
       tree
       fileName
    return 
        None
    """
    def storeTree(self, tree):
        import pickle
        fileName = self.treePath
        fw = open(fileName,'wb')
        pickle.dump(tree, fw)
        #print("\n  storeTree label: ",self.dataLabel)
        fw.close()
    
    """
    加载树
    Args
     fileNmae
     return
      tree
    """
    def grabTree(self):
        import pickle
        fileName = self.treePath
        fr = open(fileName,'rb')
        
        tree = pickle.load(fr)
        leaf = self.GetNumLeaf(tree)
        print("leaf ",leaf)
        #print("\n tree: ",tree)
        #print("\n **********************\n ",tree)
        reload(treePlot)
        treePlot.createPlot(tree)
      
        
        errorNum = 0
        num = len(self.dataList)
        
        print("\n  grabTree label: ",self.dataLabel)
        for item in self.dataList:
            classifyLabel = self.classify(tree,  self.dataLabel, item)
            
            if classifyLabel != item[-1]:
                errorNum= errorNum+1
                #print("\n classifyLabel:", classifyLabel, " real: ",item[-1])
        
        print("\n errorNUm ",errorNum, " total: ",num)
        return tree
    
    
    """
    分类
    Args
      myTree 决策树
      Labels 标签
      testVec 当前数据集
    
    return
      测试出来的标签
    """
    def classify(self, myTree, Labels, testVec):
        
        father = list(myTree.keys())[0]    
   
        childDict = myTree[father]
        index = Labels.index(father)
        
    
        
        for key in list(childDict.keys()):
            
            if testVec[index]==key:
                if type(childDict[key]).__name__=='dict':
                    testLabel = self.classify(childDict[key], Labels, testVec)
                else:
                    testLabel = childDict[key]
        return testLabel      
        
    def __init__(self):
        
        self.epsilon  = 0.01
        self.m = 0
        self.n = 0
        self.fileName =""
        self.GetPath()
        self.LoadData()
        #self.LoadDataStudent()
        self.Train()
        
        
        
     
      


shannon = Shannon()
        
   
    

    
    
    

 六   问题

   1     有1000笔贷款,其中部分是10000以上,有部分是100以下,决策树准确率只有80%多,

      如何设计决策树,保证发放贷款是盈利的。

 

   2   剪枝主要是为了解决过拟合,但是当部分结点本身就欠拟合,或者精度不高,还需要剪枝?

 

  3    生成结点的终止条件

   

参考文档:

   

      《机器学习实战》

        《机器学习与应用》

       《统计学习方法》

        https://blog.csdn.net/tkkzc3E6s4Ou4/article/details/83829616

         https://www.cnblogs.com/paisenpython/p/10371644.html
          https://blog.csdn.net/hot7732788/article/details/90070618

         https://blog.csdn.net/ccblogger/article/details/82656185

         https://www.cnblogs.com/beiyan/p/8321329.html

          https://wenku.baidu.com/view/671df33631126edb6f1a101c.html

           https://www.cnblogs.com/lpworkstudyspace1992/p/8030186.html

https://cuijiahua.com/blog/2017/12/ml_13_regtree_1.html

 

   

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐