项目地址:http://github.com/jimersylee/MachineLearningAction
决策树(Decision Trees)
类似20个问题的游戏 参与游戏的一方在脑海里想某个事物,其他参与者向他提问题,只允许提20个问题,问题的答案只能用对或错回答,问问题的人通过推断分解,逐步缩小带猜测事物的范围.决策树的工作原理与20个问题类似,用户输入一系列数据,然后给出游戏的答案
决策树 优点:计算复杂度不高,输入结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据 缺点:可能产生过度匹配问题 适用数据类型:数值型与标称型
决策树的一般流程
- 收集数据:可以适用任何方法
- 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化
- 分析数据:可以适用任何方法,构造书完成之后,我们应该检查图形是否符合预期
- 训练算法:构造树的数据结构
- 测试算法:适用经验树计算错误率
- 适用算法:此步骤可以适用于任何监督学习算法,而适用决策树可以更好地理解数据的内在含义
# coding:utf-8from mathimport log
import operator
import matplotlib.pyplotas plt
defcalcShannonEnt(dataSet):
"""
计算给定数据集的香农熵,香农熵代表数据的无序程度,香农熵越大,数据集越无序
:param dataSet: 数据集
:return:香农熵的值
"""
numEntries = len(dataSet)# 获得数据中实例的总数
labelCounts = {}# 声明一个字典# 为所有可能分类创建字典,它的键值是最后一列的数值for featVecin dataSet:
currentLabel = featVec[-1]# 目前的标签=数据集的每行的最后一个元素if currentLabelnotin labelCounts.keys():# 如果键值不存在,则扩展字典并将当前键值加入字典
labelCounts[currentLabel] = 0# 初始化为0
labelCounts[currentLabel] = labelCounts[currentLabel] + 1# 每个键值都记录了当前类别出现的次数,出现一次加1
shannonEnt = 0.0
for keyin labelCounts:
prob = float(labelCounts[key]) / numEntries# 使用所有类标签的发生频率计算类别出现的概率
shannonEnt = shannonEnt - prob * log(prob, 2)# 使用这个概率计算香农熵 以2为底求对数return shannonEnt
defcreateDataSet():
"""
创建简单鱼类鉴定数据集
:return:dataSet(数据集),labels(标签)
"""
_dataSet = [
[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']
]
_labels = ['no surfacing', 'flippers']
return _dataSet, _labels
myData, labels = createDataSet()
print myData
print calcShannonEnt(myData)
"""
result
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
0.970950594455
"""
# 熵越高,则混合的数据也越多,我们可以在数据集中添加更多的分类,观察熵是如何变化的,这里我们增加第三个名为maybe的分类,测试熵的变化
myData[0][-1] = 'maybe'# 第0行的最后一个元素置为maybeprint myData
print calcShannonEnt(myData)
"""
[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
1.37095059445
"""
# 得到熵以后,我们就可以按照获取最大信息增益的方法划分数据集# 另一个度量集合无需程度的方法是`基尼不纯度`(Gini impurity),简单地说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率"""
分类算法除了要测量信息熵,还需要划分数据集,度量花费数据集的熵,以便判断当前是否正确划分了数据集,我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照那个特征划分数据集是最好的划分方式
"""
defsplitDataSet(dataSet, axis, value):
"""
按照给定特征划分数据集
:param dataSet: 待划分的数据集
:param axis:划分数据集的特征位
:param value:预设特征值
:return:
"""
retDataSet = []# 创建新的list对象,因为函数中传递的是列表的引用,在函数内部对列表对象的修改,将会影响该列表对象的整个生存周期,因此需要创建一个新的列表for featVecin dataSet:
if featVec[axis] == value:# 如果axis位的值等于预设的特征值
reducedFeatVec = featVec[:axis]# 取行数据的前axis位
reducedFeatVec.extend(featVec[axis + 1:])# 加上axis位后的元素
retDataSet.append(reducedFeatVec)# 就是将除了axis特征位的元素抽取出来return retDataSet
# 测试splitDataSetprint "测试splitDataSet"
myData, labels = createDataSet()
print myData
newData = splitDataSet(myData, 0, 1)# 将待测试数据的第0位为1的数据的其他特征和标签筛选出来print newData
defchooseBestFeatureToSplit(dataSet):
"""
选择最好的数据集划分方式
:param dataSet: 数据集
:return:
"""
numFeatures = len(dataSet[0]) - 1
# 计算整个数据的原始香农熵
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
# 遍历数据集中的所有特征for iin range(numFeatures):
# 创建唯一的分类标签列表
featList = [example[i]for examplein dataSet]
uniqueVals = set(featList)# 将list转换为set,集合类型中每个值不同,从列表中创建集合是Python语言中得到列表中唯一元素值的最快方法
newEntropy = 0.0
# 计算每种划分方式的信息熵for valuein uniqueVals:# 遍历当前特征中的所以唯一属性值,对每个特征划分一次数据集
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)# 并对所有唯一特征值得到的熵求和
infoGian = baseEntropy - newEntropy# 信息增益=熵的减少值或者数据无序度的减少# 计算最好的信息增益if infoGian > bestInfoGain:
bestInfoGain = infoGian
bestFeature = i
return bestFeature# 返回最好特征划分的索引值# 获取最好特征划分索引值print "获取最好特征划分索引值"
myData, labels = createDataSet()
print "dataSet:" + str(myData)
print "best index:" + str(chooseBestFeatureToSplit(myData))
"""
代码表示,第0个特征是最好的用户划分数据集的特征
"""
"""
递归构建决策树
"""
defmajorityCnt(classList):
"""
类似classify0部分的投票表决
创建键值为classList中唯一值的数据字典,字典对象存储了classList中每个类标签出现的频率,最后利用operator操作键值排序字典,返回出现次数最多的分类名称
:rtype: str
:param classList: 分类名称的列表
:return: 返回出现次数最多的分类名称
"""
classCount = {}
for votein classList:
if votenotin classCount.keys():# 如果投票的类别不在classCount字典的key中,
classCount[vote] = 0# 则新建此key,value设置为0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
defcreateTree(dataSet, labels):
"""
创建树,使用了递归
:param dataSet:数据集
:param labels: 分类标签
:return: 树结构
"""
classList = [example[-1]for examplein dataSet]# 将数据集中的每一行的最后一个元素的值取出生成类别列表classListif classList.count(classList[0]) == len(classList):# 递归停止条件1return classList[0]# 类别完全相同时停止继续划分if len(dataSet[0]) == 1:# 递归停止条件2return majorityCnt(classList)# 遍历完所有特征时返回出现次数最多的
bestFeat = chooseBestFeatureToSplit(dataSet)# 选择最好的划分特征位
bestFeatLabel = labels[bestFeat]# 取出最好的标签
myTree = {bestFeatLabel: {}}# 初始化一个tree字典del (labels[bestFeat])# 删除最好的标签,防止树的下一个节点还使用此标签
featValues = [example[bestFeat]for examplein dataSet]# 创建最好的特征位的值所构建的list
uniqueVals = set(featValues)# 生成唯一的特征位的值构成的集合setfor valuein uniqueVals:
subLabels = labels[:]# 复制所有标签,且树不会和已经存在的标签搞混
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)# 递归创建树return myTree
# 构建树测试print "构建树测试"
myData, labels = createDataSet()
myTree = createTree(myData, labels)
print myTree
# coding:utf-8import matplotlib.pyplotas plt# 导入绘制图库"""
绘制树图
"""
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
defgetNumLeafs(myTree):
"""
获取叶节点的数目
:param myTree:树结构
:return: 叶节点的数目
"""
numLeafs = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for keyin secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict':# test to see if the nodes are dictonaires, if not they are leaf nodes
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
defgetTreeDepth(myTree):
"""
获取树的层数(深度)
:param myTree:
:return: 树的层数(深度)
"""
maxDepth = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for keyin secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict':# 判断此节点是否是字典,如果不是就是叶节点
thisDepth = 1 + getTreeDepth(secondDict[key])# 是字典,层数+1else:
thisDepth = 1# 叶节点,层数为1if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
defplotNode(nodeTxt, centerPt, parentPt, nodeType):
"""
绘制节点
:param nodeTxt:节点名称
:param centerPt: 目前的位置
:param parentPt: 父节点的位置
:param nodeType: 节点类型
:return:
"""
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
defplotMidText(currentPt, parentPt, txtString):
"""
绘制中间的文本
:param currentPt:目前的位置
:param parentPt: 父节点的位置
:param txtString: 想绘制的文本
:return:
"""
xMid = (parentPt[0] - currentPt[0]) / 2.0 + currentPt[0]
yMid = (parentPt[1] - currentPt[1]) / 2.0 + currentPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
defplotTree(myTree, parentPt, nodeTxt):# if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree)# this determines the x width of this tree
depth = getTreeDepth(myTree)
firstStr = myTree.keys()[0]# the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for keyin secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict':# test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key], cntrPt, str(key))# recursionelse:# it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# if you do get a dictionary you know it's a tree, and the first element will be another dictdefcreatePlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)# no ticks# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
# def createPlot():# fig = plt.figure(1, facecolor='white')# fig.clf()# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)# plt.show()defretrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
# createPlot(thisTree)
myTree = retrieveTree(0)
createPlot(myTree)