A Simple Implementation of a Decision Tree

2016-10-11
喝牛奶的鸵鸟

Entropy

I first encountered the word "entropy" in a high-school chemistry class, where entropy measures how disordered a system is. Information entropy, proposed by Shannon, the founder of information theory, is positively correlated with that thermodynamic notion. When we face a highly uncertain event, we need a lot of information to pin it down; in other words, the amount of information required is directly tied to the event's uncertainty. Let X be a discrete random variable with finitely many states, taking value x_i with probability p_i; the relation between entropy and probability is

$$H(X) = -\sum_{i=1}^{n} p_i \log_2 p_i$$
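As a quick sanity check of the formula, here is a throwaway sketch (the `entropy` helper is ad hoc, not part of the implementation later in the post):

from math import log

def entropy(probs):
    # H(X) = -sum(p_i * log2(p_i)); terms with p = 0 contribute nothing
    return -sum(p * log(p, 2) for p in probs if p > 0)

print(entropy([0.5, 0.5]))   # fair coin: 1.0 bit, maximum uncertainty
print(entropy([0.9, 0.1]))   # biased coin: ~0.469 bit, less surprising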

Decision Trees

Recently a classmate complained that the cafeteria food is awful. I eat there often, and it is not actually bad all the time; only occasionally is a meal nauseating. Since it is the only cafeteria near the dorms and everyone is lazy, plenty of people still line up there anyway. Let's use this as an example and build a small decision tree to predict what tomorrow will look like:

No.  天气   菜     就餐人数
1    雨     丰富    多
2    雨     差      多
3    雨     一般    多
4    晴     丰富    多
5    晴     差      少
6    晴     一般    少

(Columns: 天气 = weather, 雨 rainy / 晴 sunny; 菜 = food, 丰富 rich / 一般 average / 差 poor; 就餐人数 = number of diners, 多 many / 少 few. These are the same strings the code below uses.)

Information gain: how much knowing the value of feature A reduces the uncertainty (entropy) of data set D:

$$Gain(D, A) = H(D) - H(D|A), \qquad H(D|A) = \sum_{v} \frac{|D_v|}{|D|} H(D_v)$$

where D_v is the subset of D in which feature A takes the value v.

Generating the decision tree by information gain. First compute Gain(天气). The label column has four 多 and two 少, so

$$H(D) = -\tfrac{4}{6}\log_2\tfrac{4}{6} - \tfrac{2}{6}\log_2\tfrac{2}{6} \approx 0.918 \text{ bit}$$

The three 雨 rows all have label 多 (entropy 0), while the three 晴 rows split 1:2, so

$$H(D|天气) = \tfrac{3}{6}\cdot 0 + \tfrac{3}{6}\cdot 0.918 \approx 0.459, \qquad Gain(天气) = 0.918 - 0.459 \approx 0.459 \text{ bit}$$

Similarly, Gain(菜) ≈ 0.252 bit. The feature with the larger gain, 天气, is chosen as the root, and the tree grows from there:

天气
├─ 雨 → 多
└─ 晴 → 菜
         ├─ 丰富 → 多
         ├─ 差 → 少
         └─ 一般 → 少
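Before writing the full program, the two gains can be double-checked with a quick throwaway sketch (the names `rows`, `H`, and `gain` are ad hoc, not part of the implementation below):

from math import log
from collections import Counter

rows = [('雨','丰富','多'), ('雨','差','多'), ('雨','一般','多'),
        ('晴','丰富','多'), ('晴','差','少'), ('晴','一般','少')]

def H(labels):
    n = len(labels)
    return -sum(c/n * log(c/n, 2) for c in Counter(labels).values())

def gain(col):
    n = len(rows)
    split = Counter(r[col] for r in rows)           # value -> row count
    cond = sum(cnt/n * H([r[2] for r in rows if r[col] == v])
               for v, cnt in split.items())         # H(D|A)
    return H([r[2] for r in rows]) - cond

print(round(gain(0), 3))   # 天气: 0.459
print(round(gain(1), 3))   # 菜:   0.252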

The code:

from math import log
from collections import defaultdict
import json

def createDataSet():
    # each row = [天气, 菜, 就餐人数]; the last column is the class label
    features = ['天气','菜']
    dataSet =[['雨','丰富','多'],
              ['雨','差','多'],
              ['雨','一般','多'],
              ['晴','丰富','多'],
              ['晴','差','少'],
              ['晴','一般','少']]
    return dataSet,features

def _entropy(dataSet):
    '''
    Compute the entropy of a data set.
        :param dataSet: the data set
    '''
    dic = defaultdict(int)           # class label -> count
    for line in dataSet:
        dic[line[-1]] += 1
    ent = 0.0
    n = float(len(dataSet))
    for v in dic.values():
        p = v / n
        ent = ent - p * log(p,2)
    return ent
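# Sanity check: the label column of the data set above is 4x'多' and 2x'少',
# so _entropy(dataSet) = -(4/6)*log2(4/6) - (2/6)*log2(2/6) ≈ 0.918 bit.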

def _splitDataSet(dataSet,index,value):
    '''
    Split the data set: keep the rows whose feature at `index` equals `value`,
    and drop that feature column.
        :param dataSet: the data set
        :param index: feature index
        :param value: feature value
    '''
    subDataSet = []
    for line in dataSet:
        if line[index] == value:
            subDataSet.append(line[:index] + line[index+1:])
    return subDataSet
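# Example: _splitDataSet(dataSet, 0, '晴') keeps the three 晴 rows and drops
# the 天气 column: [['丰富','多'], ['差','少'], ['一般','少']]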
    
def _gain(dataSet,index):
    '''
    Compute the information gain of the feature at `index`.
        :param dataSet: the data set
        :param index: feature index
    '''
    n = float(len(dataSet))
    featureValueSet = set([line[index] for line in dataSet])
    subEnt = 0.0
    for value in featureValueSet:
        subDataSet = _splitDataSet(dataSet,index,value)
        p = len(subDataSet) / n
        subEnt = subEnt + p * _entropy(subDataSet)
    return _entropy(dataSet) - subEnt
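# Matches the hand calculation above:
# _gain(dataSet, 0) ≈ 0.459 (天气), _gain(dataSet, 1) ≈ 0.252 (菜)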

def _bestFeatureIndex(dataSet,features):
    '''
    Pick the feature with the largest information gain.
        :param dataSet: the data set
        :param features: all features
    '''
    maxGain , bestFeatureIndex = 0.0 , 0
    for i, _ in enumerate(features):
        g = _gain(dataSet,i)
        if g > maxGain:
            maxGain = g
            bestFeatureIndex = i
    return bestFeatureIndex

def createTree(dataSet,features):
    '''
    Build the decision tree recursively (ID3).
        :param dataSet: the data set
        :param features: all features (consumed as the tree grows)
    '''
    result = [line[-1] for line in dataSet]
    if len(set(result)) == 1:        # all labels agree: return a leaf
        return result[0]
    i = _bestFeatureIndex(dataSet,features)
    bestFeature = features[i]
    tree = {
        bestFeature: {}
    }
    del features[i]                  # consume the feature (mutates the caller's list)
    bestFeatureValueSet = set([line[i] for line in dataSet])
    for value in bestFeatureValueSet:
        subFeature = features[:]
        tree[bestFeature][value] = \
            createTree(_splitDataSet(dataSet,i,value),subFeature)
    return tree
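# With the cafeteria data this builds (branch order may vary, since feature
# values are iterated from a set):
# {'天气': {'雨': '多', '晴': {'菜': {'丰富': '多', '差': '少', '一般': '少'}}}}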

def testID3(tree,feat,testValue):
    '''
    Classify one sample by walking the tree.
        :param tree: the decision tree
        :param feat: the original feature list
        :param testValue: the sample to classify (e.g. ['晴','一般'])
    '''
    root = next(iter(tree))          # the feature name stored at this node
    nextDic = tree[root]
    featureIndex = 0
    for i,f in enumerate(feat):
        if f == root:
            featureIndex = i
    for key in nextDic.keys():
        if testValue[featureIndex] == key:
            if isinstance(nextDic[key], dict):
                return  testID3(nextDic[key],feat,testValue)
            else:
                return  nextDic[key]
 
dataSet,features = createDataSet()
feat = features[:]                   # keep a copy: createTree consumes features
tree = createTree(dataSet,features)
data = json.dumps(tree,ensure_ascii=False,indent=1)
print(data)
with open('data.json', 'w', encoding='utf-8') as f:
    f.write(data)                    # data is already JSON text; json.dump would double-encode it
result = testID3(tree,feat,['晴','差'])
print(f"['晴','差'] ----> {result}")
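Running the script prints the tree as JSON and then the prediction for ['晴','差']. Branch order may differ between runs because feature values are iterated from a set, but the content is:

{
 "天气": {
  "雨": "多",
  "晴": {
   "菜": {
    "丰富": "多",
    "差": "少",
    "一般": "少"
   }
  }
 }
}
['晴','差'] ----> 少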