线性回归创建的预测模型需要拟合所有的样本点,在数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型太难,而且,生活中很多问题是非线性的,不可能使用全局线性模型来拟合任何数据。
一种可行的方法是把数据集切分成很多分易建模的数据,然后利用线性回归技术来建模。如果首次切分后仍然难以拟合线性模型就继续切分。这种切分方式下,树结构和回归法就相当有用。
CART算法:分类回归树,既可用于分类也可用于回归。
第三章使用的决策树构建算法是ID3,每次选取当前最佳的特征来分割数据。属于贪心算法,不考虑能否达到全局最优。而且容易造成过拟合、不能直接处理连续型特征,只有事先将连续型特征转换成离散型,才能使用ID3算法。
而使用二元切分法则易于对树构建过程进行调整以处理连续型特征。如果特征值大于给定值就走左子树,小于给定值就走右子树。
CART算法的实现代码:
from numpy import *def loadDataSet(filename): dataMat=[] f=open(filename) for line in f.readlines(): curLine=line.strip().split('\t') floatLine=list(map(float,curLine)) dataMat.append(floatLine) return dataMatdef binSplitDataSet(dataSet,feature,value): mat0=dataSet[nonzero(dataSet[:,feature]>value)[0],:] mat1=dataSet[nonzero(dataSet[:,feature]<=value)[0],:] return mat0,mat1def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)): feat, val = chooseBestSplit(dataSet, leafType, errType, ops) if feat == None: return val retTree = {} retTree['spInd'] = feat retTree['spVal'] = val lSet, rSet = binSplitDataSet(dataSet, feat, val) retTree['left'] = createTree(lSet, leafType, errType, ops) retTree['right'] = createTree(rSet, leafType, errType, ops) return retTree
chooseBestSplit()函数暂未实现。
将CART算法用于回归:回归树假设叶子节点是常数值。用平方误差的总值(总方差)来计算连续型数值的混乱程度。总方差等于均方差乘以数据集中样本点的个数。
chooseBestSplit():给定某个误差计算方法,该函数会找到数据集上最佳的二元切分方式。还要确定什么时候停止切分,一旦停止切分就会生成一个叶子节点。所以:用最佳方式切分数据集和生成相应的叶节点。
伪代码:
对每个特征: 对每个特征值: 将数据集切分成两份 计算切分后的误差 如果当前误差小于当前最小误差,将当前切分设定为最佳切分并更新最小误差 返回最佳切分的特征和阈值
切分函数的实现:
def regLeaf(dataSet): #负责生成叶节点,当chooseBestSplit函数确定不再对数据进行切分时,将调用regLeaf函数得到叶节点的模型 return mean(dataSet[:,-1]) #在回归树中,此模型就是目标变量的均值def regErr(dataSet): # 误差估计函数,计算目标变量的平方误差,需要返回总误差,即为均方误差乘以数据集中样本个数 return var(dataSet[:, -1]) * shape(dataSet)[0]def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): #ops为用户指定的参数,用于控制函数的停止时机 tolS = ops[0] # 容许的误差下降值 tolN = ops[1] # 切分的最少样本数 if len(set(dataSet[:, -1].T.tolist()[0])) == 1: # 统计不同剩余特征值得数目,如果数目为一,就不需要再切分而直接返回 return None, leafType(dataSet) else: m, n = shape(dataSet) S = errType(dataSet) #误差 bestS = inf #最小误差 bestIndex = 0 bestValue = 0 for featIndex in range(n - 1): # 对所有特征进行遍历,找到最佳切分方式。最佳切分就是使得切分后能达到最低误差的切分 # for splitVal in set(dataSet[:, featIndex]): # 遍历某个特征的所有特征值 for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]): mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) # 按照某个特征的某个值将数据切分成两个数据子集 if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): # 如果某个子集行数不大于tolN,也不应该切分 continue newS = errType(mat0) + errType(mat1) # 新误差由切分后的两个数据子集组成的误差 if newS < bestS: # 判断新切分能否降低误差 bestIndex = featIndex bestValue = splitVal bestS = newS if (S - bestS) < tolS: # 如果误差降低不大则退出 return None, leafType(dataSet) mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): # 如果切分出的数据集很小则退出 return None, leafType(dataSet) return bestIndex, bestValue
regLeaf():负责生成叶节点,即求当前数据集目标值的平均值作为回归预测值。当chooseBestSplit()确定不再对数据进行切分时,将调用regLeaf()函数来得到叶节点的模型。回归树中,该模型是目标变量的均值。
regErr():误差估计函数。在给定数据集上计算目标变量的平方误差。
chooseBestSplit():构建回归树的核心函数。目的是找到数据的最佳二元切分方式。如果找不到好的二元切分,就返回None并同时调用regLeaf()方法来产生叶节点。
运行代码:
if __name__=='__main__': myMat=loadDataSet('ex00.txt') myMat=mat(myMat) result=createTree(myMat) print(result)
输出为:
{'spInd': 0, 'spVal': 0.48813, 'right': -0.04465028571428572, 'left': 1.0180967672413792}
只有两个叶节点,对照下面的散点图可以看出,在数据0.48813左侧的数据,回归预测值为-0.04465,右侧预测值为1.018。
数据集散点图:
因为数据集简单,所以得到的回归树也简单。
更换数据集测试:
if __name__=='__main__': myMat2=loadDataSet('ex2.txt') myMat2=mat(myMat2) myTree = createTree(myMat2, ops=(0, 1)) print(myTree)
输出:
{'spInd': 0, 'spVal': 0.499171, 'right': {'spInd': 0, 'spVal': 0.457563, 'right': {'spInd': 0, 'spVal': 0.455761, 'right': {'spInd': 0, 'spVal': 0.126833, 'right': {'spInd': 0, 'spVal': 0.124723, 'right': {'spInd': 0, 'spVal': 0.085111, 'right': {'spInd': 0, 'spVal': 0.084661, 'right': {'spInd': 0, 'spVal': 0.080061, 'right': {'spInd': 0, 'spVal': 0.068373, 'right': {'spInd': 0, 'spVal': 0.061219, 'right': {'spInd': 0, 'spVal': 0.044737, 'right': {'spInd': 0, 'spVal': 0.028546, 'right': {'spInd': 0, 'spVal': 0.000256, 'right': 9.668106, 'left': -8.377094}, 'left': {'spInd': 0, 'spVal': 0.039914, 'right': 11.220099, 'left': 3.855393}}, 'left': {'spInd': 0, 'spVal': 0.053764, 'right': -13.731698, 'left': {'spInd': 0, 'spVal': 0.055862, 'right': -3.131497, 'left': 6.695567}}}, 'left': -15.160836}, 'left': {'spInd': 0, 'spVal': 0.079632, 'right': 29.420068, 'left': 2.229873}}, 'left': -24.132226}, 'left': 37.820659}, 'left': {'spInd': 0, 'spVal': 0.108801, 'right': {'spInd': 0, 'spVal': 0.10796, 'right': {'spInd': 0, 'spVal': 0.085873, 'right': -10.137104, 'left': -1.293195}, 'left': -16.106164}, 'left': {'spInd': 0, 'spVal': 0.11515, 'right': 13.795828, 'left': -1.402796}}}, 'left': 22.891675}, 'left': {'spInd': 0, 'spVal': 0.130626, 'right': -39.524461, 'left': {'spInd': 0, 'spVal': 0.382037, 'right': {'spInd': 0, 'spVal': 0.335182, 'right': {'spInd': 0, 'spVal': 0.324274, 'right': {'spInd': 0, 'spVal': 0.309133, 'right': {'spInd': 0, 'spVal': 0.131833, 'right': 22.478291, 'left': {'spInd': 0, 'spVal': 0.138619, 'right': -29.087463, 'left': {'spInd': 0, 'spVal': 0.156067, 'right': {'spInd': 0, 'spVal': 0.13988, 'right': 7.336784, 'left': 7.557349}, 'left': {'spInd': 0, 'spVal': 0.166765, 'right': {'spInd': 0, 'spVal': 0.156273, 'right': 0.225886, 'left': {'spInd': 0, 'spVal': 0.164134, 'right': -27.405211, 'left': {'spInd': 0, 'spVal': 0.166431, 'right': -6.512506, 'left': -14.740059}}}, 'left': {'spInd': 0, 'spVal': 0.193282, 'right': {'spInd': 0, 'spVal': 0.176523, 'right': 0.946348, 'left': 18.208423}, 'left': {'spInd': 0, 'spVal': 0.211633, 'right': {'spInd': 0, 'spVal': 0.202161, 'right': {'spInd': 0, 'spVal': 0.199903, 'right': -3.372472, 'left': -1.983889}, 'left': {'spInd': 0, 'spVal': 0.203993, 'right': -22.379119, 'left': {'spInd': 0, 'spVal': 0.206207, 'right': -12.619036, 'left': -8.332207}}}, 'left': {'spInd': 0, 'spVal': 0.228473, 'right': {'spInd': 0, 'spVal': 0.222271, 'right': {'spInd': 0, 'spVal': 0.218321, 'right': {'spInd': 0, 'spVal': 0.217214, 'right': -3.958752, 'left': 1.410768}, 'left': -9.255852}, 'left': {'spInd': 0, 'spVal': 0.2232, 'right': 15.501642, 'left': 19.425158}}, 'left': {'spInd': 0, 'spVal': 0.25807, 'right': {'spInd': 0, 'spVal': 0.228628, 'right': -2.266273, 'left': {'spInd': 0, 'spVal': 0.228751, 'right': -30.812912, 'left': {'spInd': 0, 'spVal': 0.232802, 'right': 1.222318, 'left': -20.425137}}}, 'left': {'spInd': 0, 'spVal': 0.284794, 'right': {'spInd': 0, 'spVal': 0.273863, 'right': {'spInd': 0, 'spVal': 0.264926, 'right': {'spInd': 0, 'spVal': 0.264639, 'right': 2.557923, 'left': 5.280579}, 'left': -9.457556}, 'left': 35.623746}, 'left': {'spInd': 0, 'spVal': 0.300318, 'right': {'spInd': 0, 'spVal': 0.297107, 'right': {'spInd': 0, 'spVal': 0.295993, 'right': {'spInd': 0, 'spVal': 0.290749, 'right': -14.391613, 'left': -14.988279}, 'left': -1.798377}, 'left': -18.051318}, 'left': 8.814725}}}}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.310956, 'right': -49.939516, 'left': {'spInd': 0, 'spVal': 0.318309, 'right': -27.605424, 'left': -13.189243}}}, 'left': {'spInd': 0, 'spVal': 0.32889, 'right': 39.783113, 'left': {'spInd': 0, 'spVal': 0.331364, 'right': -1.290825, 'left': {'spInd': 0, 'spVal': 0.3349, 'right': 18.97665, 'left': 2.768225}}}}, 'left': {'spInd': 0, 'spVal': 0.370042, 'right': {'spInd': 0, 'spVal': 0.35679, 'right': {'spInd': 0, 'spVal': 0.350725, 'right': {'spInd': 0, 'spVal': 0.350065, 'right': {'spInd': 0, 'spVal': 0.342761, 'right': {'spInd': 0, 'spVal': 0.342155, 'right': {'spInd': 0, 'spVal': 0.3417, 'right': -23.547711, 'left': -16.930416}, 'left': -31.584855}, 'left': -1.319852}, 'left': -40.086564}, 'left': {'spInd': 0, 'spVal': 0.351478, 'right': -0.461116, 'left': -19.526539}}, 'left': -32.124495}, 'left': {'spInd': 0, 'spVal': 0.378965, 'right': {'spInd': 0, 'spVal': 0.373501, 'right': -8.228297, 'left': {'spInd': 0, 'spVal': 0.377383, 'right': 5.241196, 'left': 13.583555}}, 'left': -29.007783}}}, 'left': {'spInd': 0, 'spVal': 0.388789, 'right': {'spInd': 0, 'spVal': 0.385021, 'right': 24.816941, 'left': 21.578007}, 'left': {'spInd': 0, 'spVal': 0.437652, 'right': {'spInd': 0, 'spVal': 0.412516, 'right': {'spInd': 0, 'spVal': 0.403228, 'right': {'spInd': 0, 'spVal': 0.391609, 'right': 3.001104, 'left': -1.729244}, 'left': -26.419289}, 'left': {'spInd': 0, 'spVal': 0.418943, 'right': 44.161493, 'left': {'spInd': 0, 'spVal': 0.426711, 'right': -21.594268, 'left': {'spInd': 0, 'spVal': 0.428582, 'right': 15.224266, 'left': 19.745224}}}}, 'left': {'spInd': 0, 'spVal': 0.454312, 'right': {'spInd': 0, 'spVal': 0.446196, 'right': -5.108172, 'left': {'spInd': 0, 'spVal': 0.451087, 'right': -28.724685, 'left': -20.360067}}, 'left': {'spInd': 0, 'spVal': 0.454375, 'right': 3.043912, 'left': 9.841938}}}}}}}, 'left': -34.044555}, 'left': {'spInd': 0, 'spVal': 0.465561, 'right': {'spInd': 0, 'spVal': 0.463241, 'right': 17.171057, 'left': 30.051931}, 'left': {'spInd': 0, 'spVal': 0.467383, 'right': {'spInd': 0, 'spVal': 0.46568, 'right': -23.777531, 'left': -9.712925}, 'left': {'spInd': 0, 'spVal': 0.483803, 'right': 5.224234, 'left': {'spInd': 0, 'spVal': 0.487381, 'right': 27.729263, 'left': {'spInd': 0, 'spVal': 0.487537, 'right': 5.149336, 'left': 11.924204}}}}}}, 'left': {'spInd': 0, 'spVal': 0.729397, 'right': {'spInd': 0, 'spVal': 0.640515, 'right': {'spInd': 0, 'spVal': 0.613004, 'right': {'spInd': 0, 'spVal': 0.606417, 'right': {'spInd': 0, 'spVal': 0.513332, 'right': {'spInd': 0, 'spVal': 0.508548, 'right': {'spInd': 0, 'spVal': 0.508542, 'right': 96.403373, 'left': 93.292829}, 'left': 101.075609}, 'left': {'spInd': 0, 'spVal': 0.533511, 'right': {'spInd': 0, 'spVal': 0.51915, 'right': 116.176162, 'left': {'spInd': 0, 'spVal': 0.531944, 'right': 124.795495, 'left': 129.766743}}, 'left': {'spInd': 0, 'spVal': 0.548539, 'right': {'spInd': 0, 'spVal': 0.546601, 'right': {'spInd': 0, 'spVal': 0.537834, 'right': 90.995536, 'left': {'spInd': 0, 'spVal': 0.543843, 'right': 98.36201, 'left': 96.319043}}, 'left': 83.114502}, 'left': {'spInd': 0, 'spVal': 0.553797, 'right': {'spInd': 0, 'spVal': 0.549814, 'right': 137.267576, 'left': 120.857321}, 'left': {'spInd': 0, 'spVal': 0.560301, 'right': 82.903945, 'left': {'spInd': 0, 'spVal': 0.599142, 'right': {'spInd': 0, 'spVal': 0.589806, 'right': {'spInd': 0, 'spVal': 0.582311, 'right': {'spInd': 0, 'spVal': 0.571214, 'right': {'spInd': 0, 'spVal': 0.569327, 'right': 108.435392, 'left': 114.872056}, 'left': 82.589328}, 'left': {'spInd': 0, 'spVal': 0.585413, 'right': 125.295113, 'left': 98.674874}}, 'left': 130.378529}, 'left': 93.521396}}}}}}, 'left': 168.180746}, 'left': {'spInd': 0, 'spVal': 0.623909, 'right': {'spInd': 0, 'spVal': 0.618868, 'right': 76.917665, 'left': 87.181863}, 'left': {'spInd': 0, 'spVal': 0.628061, 'right': {'spInd': 0, 'spVal': 0.624827, 'right': 105.970743, 'left': 117.628346}, 'left': {'spInd': 0, 'spVal': 0.637999, 'right': {'spInd': 0, 'spVal': 0.632691, 'right': 93.645293, 'left': 91.656617}, 'left': 82.713621}}}}, 'left': {'spInd': 0, 'spVal': 0.642373, 'right': 140.613941, 'left': {'spInd': 0, 'spVal': 0.642707, 'right': 82.500766, 'left': {'spInd': 0, 'spVal': 0.665329, 'right': {'spInd': 0, 'spVal': 0.661073, 'right': {'spInd': 0, 'spVal': 0.652462, 'right': 112.715799, 'left': 115.687524}, 'left': 121.980607}, 'left': {'spInd': 0, 'spVal': 0.706961, 'right': {'spInd': 0, 'spVal': 0.698472, 'right': {'spInd': 0, 'spVal': 0.689099, 'right': {'spInd': 0, 'spVal': 0.666452, 'right': {'spInd': 0, 'spVal': 0.665652, 'right': 105.547997, 'left': 120.014736}, 'left': {'spInd': 0, 'spVal': 0.667851, 'right': 92.449664, 'left': {'spInd': 0, 'spVal': 0.680486, 'right': 110.367074, 'left': 112.378209}}}, 'left': 120.521925}, 'left': {'spInd': 0, 'spVal': 0.69892, 'right': 92.470636, 'left': {'spInd': 0, 'spVal': 0.699873, 'right': 115.586605, 'left': {'spInd': 0, 'spVal': 0.70639, 'right': 105.062147, 'left': 106.180427}}}}, 'left': {'spInd': 0, 'spVal': 0.70889, 'right': 135.416767, 'left': {'spInd': 0, 'spVal': 0.716211, 'right': {'spInd': 0, 'spVal': 0.710234, 'right': 108.553919, 'left': 103.345308}, 'left': 110.90283}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.952833, 'right': {'spInd': 0, 'spVal': 0.759504, 'right': {'spInd': 0, 'spVal': 0.740859, 'right': {'spInd': 0, 'spVal': 0.731636, 'right': 73.912028, 'left': 93.773929}, 'left': {'spInd': 0, 'spVal': 0.757527, 'right': 63.549854, 'left': 81.106762}}, 'left': {'spInd': 0, 'spVal': 0.763328, 'right': 115.199195, 'left': {'spInd': 0, 'spVal': 0.769043, 'right': 64.041941, 'left': {'spInd': 0, 'spVal': 0.790312, 'right': {'spInd': 0, 'spVal': 0.786865, 'right': {'spInd': 0, 'spVal': 0.785574, 'right': {'spInd': 0, 'spVal': 0.777582, 'right': 100.838446, 'left': 107.024467}, 'left': 100.598825}, 'left': {'spInd': 0, 'spVal': 0.787755, 'right': 118.642009, 'left': 110.15973}}, 'left': {'spInd': 0, 'spVal': 0.806158, 'right': {'spInd': 0, 'spVal': 0.799873, 'right': {'spInd': 0, 'spVal': 0.798198, 'right': 76.853728, 'left': 91.368473}, 'left': 62.877698}, 'left': {'spInd': 0, 'spVal': 0.815215, 'right': {'spInd': 0, 'spVal': 0.811602, 'right': {'spInd': 0, 'spVal': 0.811363, 'right': 112.981216, 'left': 99.841379}, 'left': 118.319942}, 'left': {'spInd': 0, 'spVal': 0.833026, 'right': {'spInd': 0, 'spVal': 0.823848, 'right': {'spInd': 0, 'spVal': 0.819722, 'right': 70.054508, 'left': 59.342323}, 'left': 76.723835}, 'left': {'spInd': 0, 'spVal': 0.841547, 'right': {'spInd': 0, 'spVal': 0.838587, 'right': 134.089674, 'left': 115.669032}, 'left': {'spInd': 0, 'spVal': 0.841625, 'right': 60.552308, 'left': {'spInd': 0, 'spVal': 0.944221, 'right': {'spInd': 0, 'spVal': 0.85497, 'right': {'spInd': 0, 'spVal': 0.84294, 'right': 95.893131, 'left': {'spInd': 0, 'spVal': 0.847219, 'right': 76.240984, 'left': 89.20993}}, 'left': {'spInd': 0, 'spVal': 0.936524, 'right': {'spInd': 0, 'spVal': 0.934853, 'right': {'spInd': 0, 'spVal': 0.925782, 'right': {'spInd': 0, 'spVal': 0.910975, 'right': {'spInd': 0, 'spVal': 0.901444, 'right': {'spInd': 0, 'spVal': 0.901421, 'right': {'spInd': 0, 'spVal': 0.892999, 'right': {'spInd': 0, 'spVal': 0.888426, 'right': {'spInd': 0, 'spVal': 0.872199, 'right': {'spInd': 0, 'spVal': 0.866451, 'right': {'spInd': 0, 'spVal': 0.856421, 'right': 107.166848, 'left': 94.402102}, 'left': 111.552716}, 'left': {'spInd': 0, 'spVal': 0.883615, 'right': {'spInd': 0, 'spVal': 0.872883, 'right': 95.887712, 'left': 95.348184}, 'left': {'spInd': 0, 'spVal': 0.885676, 'right': 108.045948, 'left': 94.896354}}}, 'left': 82.436686}, 'left': {'spInd': 0, 'spVal': 0.900699, 'right': {'spInd': 0, 'spVal': 0.896683, 'right': 107.00162, 'left': 109.188248}, 'left': 100.133819}}, 'left': 87.300625}, 'left': {'spInd': 0, 'spVal': 0.908629, 'right': 118.513475, 'left': 106.814667}}, 'left': {'spInd': 0, 'spVal': 0.912161, 'right': 85.005351, 'left': {'spInd': 0, 'spVal': 0.915263, 'right': 96.71761, 'left': 92.074619}}}, 'left': 115.753994}, 'left': 65.548418}, 'left': {'spInd': 0, 'spVal': 0.937766, 'right': 119.949824, 'left': 100.120253}}}, 'left': {'spInd': 0, 'spVal': 0.948822, 'right': 69.318649, 'left': {'spInd': 0, 'spVal': 0.949198, 'right': 105.752508, 'left': {'spInd': 0, 'spVal': 0.952377, 'right': 73.520802, 'left': 100.649591}}}}}}}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.965969, 'right': {'spInd': 0, 'spVal': 0.956951, 'right': {'spInd': 0, 'spVal': 0.953902, 'right': 130.92648, 'left': {'spInd': 0, 'spVal': 0.954711, 'right': 100.935789, 'left': 82.016541}}, 'left': {'spInd': 0, 'spVal': 0.958512, 'right': 135.837013, 'left': {'spInd': 0, 'spVal': 0.960398, 'right': 123.559747, 'left': 112.386764}}}, 'left': {'spInd': 0, 'spVal': 0.968621, 'right': 98.648346, 'left': 86.399637}}}}}
散点图:
得到的树很复杂,改变ops元组的值:
if __name__=='__main__': myMat2 = loadDataSet('ex2.txt') myMat2 = mat(myMat2) myTree = createTree(myMat2, ops=(10000, 4)) print(myTree)
输出:
{'spInd': 0, 'spVal': 0.499171, 'left': 101.35815937735848, 'right': -2.637719329787234}
也可以得到仅有两个叶节点的树。
树剪枝:一棵树如果节点过多,表明该模型可能对数据进行了过拟合。通过降低决策树的复杂度来避免过拟合的过程称为“剪枝”。
在函数chooseBestSplit()中的提前终止条件,实际上是“预剪枝”操作,预剪枝操作对于参数ops元组非常敏感,难以获得有效的回归树。
后剪枝:利用测试集对数进行剪枝。由于不需要用户指定参数,后剪枝是一种更理想化的剪枝方法。
首先将数据集划分为训练集和测试集。先使用训练集构建出一棵足够复杂的树便于剪枝。然后从上到下找到叶节点,用测试集来判断这些叶节点合并能不能降低测试误差,如果可以的话就合并。
伪代码如下:
基于已有的树切分测试数据: 如果存在任一子集是一棵树,则在该子集递归剪枝过程 计算将当前两个叶子节点合并后的误差 计算不合并的误差 如果合并会降低误差则合并
回归树剪枝函数prune():
def isTree(obj): # 测试输入变量是否是一棵树,返回布尔型的结果,用于判断当前处理的节点是否是叶节点 return (type(obj).__name__ == "dict")def getMean(tree): # 递归函数,从上到下遍历树直到叶节点为止。如果找到两个叶节点则计算它们的平均值。该函数对树进行塌陷处理 if isTree(tree["right"]): tree["right"] = getMean(tree["right"]) if isTree(tree["left"]): tree["left"] = getMean(tree["left"]) return (tree["left"] + tree["right"]) / 2.0def prune(tree, testData): #参数:待剪枝的树与剪枝所需的测试数据 if shape(testData)[0] == 0: #没有测试数据则对树进行塌陷处理 return getMean(tree) if (isTree(tree['right']) or isTree(tree['left'])): # lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet) if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet) if not isTree(tree['left']) and not isTree(tree['right']): lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + sum(power(rSet[:, -1] - tree['right'], 2)) treeMean = (tree['left'] + tree['right']) / 2.0 errorMerge = sum(power(testData[:, -1] - treeMean, 2)) if errorMerge < errorNoMerge: print("融合") return treeMean else: return tree else: return tree
isTree():测试输入变量是否是一棵树,返回布尔值的结果。用于判断当前处理的节点是不是叶子节点。
getMean():递归函数,从上到下遍历树直到叶节点。如果找到两个叶节点就返回其平均值。该函数对树进行塌陷处理。
prune():参数为待剪枝的树和剪枝所需的测试数据集。
测试:
if __name__=='__main__': myMat2=loadDataSet('ex2.txt') myMat2=mat(myMat2) myTree = createTree(myMat2, ops=(0, 1)) myDat2Test = loadDataSet("ex2test.txt") myMat2Test = mat(myDat2Test) result=prune(myTree, myMat2Test) print(result)
输出:
融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合融合{ 'left': { 'left': { 'left': { 'left': 92.5239915, 'spInd': 0, 'spVal': 0.965969, 'right': { 'left': { 'left': { 'left': 112.386764, 'spInd': 0, 'spVal': 0.960398, 'right': 123.559747}, 'spInd': 0, 'spVal': 0.958512, 'right': 135.837013}, 'spInd': 0, 'spVal': 0.956951, 'right': 111.2013225}}, 'spInd': 0, 'spVal': 0.952833, 'right': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': 96.41885225, 'spInd': 0, 'spVal': 0.948822, 'right': 69.318649}, 'spInd': 0, 'spVal': 0.944221, 'right': { 'left': { 'left': 110.03503850000001, 'spInd': 0, 'spVal': 0.936524, 'right': { 'left': 65.548418, 'spInd': 0, 'spVal': 0.934853, 'right': { 'left': 115.753994, 'spInd': 0, 'spVal': 0.925782, 'right': { 'left': { 'left': 94.3961145, 'spInd': 0, 'spVal': 0.912161, 'right': 85.005351}, 'spInd': 0, 'spVal': 0.910975, 'right': { 'left': { 'left': 106.814667, 'spInd': 0, 'spVal': 0.908629, 'right': 118.513475}, 'spInd': 0, 'spVal': 0.901444, 'right': { 'left': 87.300625, 'spInd': 0, 'spVal': 0.901421, 'right': { 'left': { 'left': 100.133819, 'spInd': 0, 'spVal': 0.900699, 'right': 108.094934}, 'spInd': 0, 'spVal': 0.892999, 'right': { 'left': 82.436686, 'spInd': 0, 'spVal': 0.888426, 'right': { 'left': 98.54454949999999, 'spInd': 0, 'spVal': 0.872199, 'right': 106.16859550000001}}}}}}}}}, 'spInd': 0, 'spVal': 0.85497, 'right': { 'left': { 'left': 89.20993, 'spInd': 0, 'spVal': 0.847219, 'right': 76.240984}, 'spInd': 0, 'spVal': 0.84294, 'right': 95.893131}}}, 'spInd': 0, 'spVal': 0.841625, 'right': 60.552308}, 'spInd': 0, 'spVal': 0.841547, 'right': 124.87935300000001}, 'spInd': 0, 'spVal': 0.833026, 'right': { 'left': 76.723835, 'spInd': 0, 'spVal': 0.823848, 'right': { 'left': 59.342323, 'spInd': 0, 'spVal': 0.819722, 'right': 70.054508}}}, 'spInd': 0, 'spVal': 0.815215, 'right': { 'left': 118.319942, 'spInd': 0, 'spVal': 0.811602, 'right': { 'left': 99.841379, 'spInd': 0, 'spVal': 0.811363, 'right': 112.981216}}}, 'spInd': 0, 'spVal': 0.806158, 'right': 73.49439925}, 'spInd': 0, 'spVal': 0.790312, 'right': { 'left': 114.4008695, 'spInd': 0, 'spVal': 0.786865, 'right': 102.26514075}}, 'spInd': 0, 'spVal': 0.769043, 'right': 64.041941}, 'spInd': 0, 'spVal': 0.763328, 'right': 115.199195}, 'spInd': 0, 'spVal': 0.759504, 'right': 78.08564325}}, 'spInd': 0, 'spVal': 0.729397, 'right': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': 110.90283, 'spInd': 0, 'spVal': 0.716211, 'right': { 'left': 103.345308, 'spInd': 0, 'spVal': 0.710234, 'right': 108.553919}}, 'spInd': 0, 'spVal': 0.70889, 'right': 135.416767}, 'spInd': 0, 'spVal': 0.706961, 'right': { 'left': { 'left': { 'left': { 'left': 106.180427, 'spInd': 0, 'spVal': 0.70639, 'right': 105.062147}, 'spInd': 0, 'spVal': 0.699873, 'right': 115.586605}, 'spInd': 0, 'spVal': 0.69892, 'right': 92.470636}, 'spInd': 0, 'spVal': 0.698472, 'right': { 'left': 120.521925, 'spInd': 0, 'spVal': 0.689099, 'right': { 'left': 101.91115275, 'spInd': 0, 'spVal': 0.666452, 'right': 112.78136649999999}}}}, 'spInd': 0, 'spVal': 0.665329, 'right': { 'left': 121.980607, 'spInd': 0, 'spVal': 0.661073, 'right': { 'left': 115.687524, 'spInd': 0, 'spVal': 0.652462, 'right': 112.715799}}}, 'spInd': 0, 'spVal': 0.642707, 'right': 82.500766}, 'spInd': 0, 'spVal': 0.642373, 'right': 140.613941}, 'spInd': 0, 'spVal': 0.640515, 'right': { 'left': { 'left': { 'left': { 'left': 82.713621, 'spInd': 0, 'spVal': 0.637999, 'right': { 'left': 91.656617, 'spInd': 0, 'spVal': 0.632691, 'right': 93.645293}}, 'spInd': 0, 'spVal': 0.628061, 'right': { 'left': 117.628346, 'spInd': 0, 'spVal': 0.624827, 'right': 105.970743}}, 'spInd': 0, 'spVal': 0.623909, 'right': 82.04976400000001}, 'spInd': 0, 'spVal': 0.613004, 'right': { 'left': 168.180746, 'spInd': 0, 'spVal': 0.606417, 'right': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': 93.521396, 'spInd': 0, 'spVal': 0.599142, 'right': { 'left': 130.378529, 'spInd': 0, 'spVal': 0.589806, 'right': { 'left': 111.9849935, 'spInd': 0, 'spVal': 0.582311, 'right': { 'left': 82.589328, 'spInd': 0, 'spVal': 0.571214, 'right': { 'left': 114.872056, 'spInd': 0, 'spVal': 0.569327, 'right': 108.435392}}}}}, 'spInd': 0, 'spVal': 0.560301, 'right': 82.903945}, 'spInd': 0, 'spVal': 0.553797, 'right': 129.0624485}, 'spInd': 0, 'spVal': 0.548539, 'right': { 'left': 83.114502, 'spInd': 0, 'spVal': 0.546601, 'right': { 'left': 97.3405265, 'spInd': 0, 'spVal': 0.537834, 'right': 90.995536}}}, 'spInd': 0, 'spVal': 0.533511, 'right': { 'left': { 'left': 129.766743, 'spInd': 0, 'spVal': 0.531944, 'right': 124.795495}, 'spInd': 0, 'spVal': 0.51915, 'right': 116.176162}}, 'spInd': 0, 'spVal': 0.513332, 'right': { 'left': 101.075609, 'spInd': 0, 'spVal': 0.508548, 'right': { 'left': 93.292829, 'spInd': 0, 'spVal': 0.508542, 'right': 96.403373}}}}}}}, 'spInd': 0, 'spVal': 0.499171, 'right': { 'left': { 'left': { 'left': { 'left': { 'left': 8.53677, 'spInd': 0, 'spVal': 0.487381, 'right': 27.729263}, 'spInd': 0, 'spVal': 0.483803, 'right': 5.224234}, 'spInd': 0, 'spVal': 0.467383, 'right': { 'left': -9.712925, 'spInd': 0, 'spVal': 0.46568, 'right': -23.777531}}, 'spInd': 0, 'spVal': 0.465561, 'right': { 'left': 30.051931, 'spInd': 0, 'spVal': 0.463241, 'right': 17.171057}}, 'spInd': 0, 'spVal': 0.457563, 'right': { 'left': -34.044555, 'spInd': 0, 'spVal': 0.455761, 'right': { 'left': { 'left': { 'left': { 'left': { 'left': -4.1911745, 'spInd': 0, 'spVal': 0.437652, 'right': { 'left': { 'left': { 'left': { 'left': 19.745224, 'spInd': 0, 'spVal': 0.428582, 'right': 15.224266}, 'spInd': 0, 'spVal': 0.426711, 'right': -21.594268}, 'spInd': 0, 'spVal': 0.418943, 'right': 44.161493}, 'spInd': 0, 'spVal': 0.412516, 'right': { 'left': -26.419289, 'spInd': 0, 'spVal': 0.403228, 'right': 0.6359300000000001}}}, 'spInd': 0, 'spVal': 0.388789, 'right': 23.197474}, 'spInd': 0, 'spVal': 0.382037, 'right': { 'left': { 'left': { 'left': -29.007783, 'spInd': 0, 'spVal': 0.378965, 'right': { 'left': { 'left': 13.583555, 'spInd': 0, 'spVal': 0.377383, 'right': 5.241196}, 'spInd': 0, 'spVal': 0.373501, 'right': -8.228297}}, 'spInd': 0, 'spVal': 0.370042, 'right': { 'left': -32.124495, 'spInd': 0, 'spVal': 0.35679, 'right': { 'left': -9.9938275, 'spInd': 0, 'spVal': 0.350725, 'right': -26.851234812500003}}}, 'spInd': 0, 'spVal': 0.335182, 'right': { 'left': 22.286959625, 'spInd': 0, 'spVal': 0.324274, 'right': { 'left': { 'left': -20.3973335, 'spInd': 0, 'spVal': 0.310956, 'right': -49.939516}, 'spInd': 0, 'spVal': 0.309133, 'right': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': { 'left': 8.814725, 'spInd': 0, 'spVal': 0.300318, 'right': { 'left': -18.051318, 'spInd': 0, 'spVal': 0.297107, 'right': { 'left': -1.798377, 'spInd': 0, 'spVal': 0.295993, 'right': { 'left': -14.988279, 'spInd': 0, 'spVal': 0.290749, 'right': -14.391613}}}}, 'spInd': 0, 'spVal': 0.284794, 'right': { 'left': 35.623746, 'spInd': 0, 'spVal': 0.273863, 'right': { 'left': -9.457556, 'spInd': 0, 'spVal': 0.264926, 'right': { 'left': 5.280579, 'spInd': 0, 'spVal': 0.264639, 'right': 2.557923}}}}, 'spInd': 0, 'spVal': 0.25807, 'right': { 'left': { 'left': -9.601409499999999, 'spInd': 0, 'spVal': 0.228751, 'right': -30.812912}, 'spInd': 0, 'spVal': 0.228628, 'right': -2.266273}}, 'spInd': 0, 'spVal': 0.228473, 'right': 6.099239}, 'spInd': 0, 'spVal': 0.211633, 'right': { 'left': -16.42737025, 'spInd': 0, 'spVal': 0.202161, 'right': -2.6781805}}, 'spInd': 0, 'spVal': 0.193282, 'right': 9.5773855}, 'spInd': 0, 'spVal': 0.166765, 'right': { 'left': { 'left': { 'left': -14.740059, 'spInd': 0, 'spVal': 0.166431, 'right': -6.512506}, 'spInd': 0, 'spVal': 0.164134, 'right': -27.405211}, 'spInd': 0, 'spVal': 0.156273, 'right': 0.225886}}, 'spInd': 0, 'spVal': 0.156067, 'right': { 'left': 7.557349, 'spInd': 0, 'spVal': 0.13988, 'right': 7.336784}}, 'spInd': 0, 'spVal': 0.138619, 'right': -29.087463}, 'spInd': 0, 'spVal': 0.131833, 'right': 22.478291}}}}}, 'spInd': 0, 'spVal': 0.130626, 'right': -39.524461}, 'spInd': 0, 'spVal': 0.126833, 'right': { 'left': 22.891675, 'spInd': 0, 'spVal': 0.124723, 'right': { 'left': { 'left': 6.196516, 'spInd': 0, 'spVal': 0.108801, 'right': { 'left': -16.106164, 'spInd': 0, 'spVal': 0.10796, 'right': { 'left': -1.293195, 'spInd': 0, 'spVal': 0.085873, 'right': -10.137104}}}, 'spInd': 0, 'spVal': 0.085111, 'right': { 'left': 37.820659, 'spInd': 0, 'spVal': 0.084661, 'right': { 'left': -24.132226, 'spInd': 0, 'spVal': 0.080061, 'right': { 'left': 15.824970500000001, 'spInd': 0, 'spVal': 0.068373, 'right': { 'left': -15.160836, 'spInd': 0, 'spVal': 0.061219, 'right': { 'left': { 'left': { 'left': 6.695567, 'spInd': 0, 'spVal': 0.055862, 'right': -3.131497}, 'spInd': 0, 'spVal': 0.053764, 'right': -13.731698}, 'spInd': 0, 'spVal': 0.044737, 'right': 4.091626}}}}}}}}}}}
虽然合并了很多叶节点,但剪枝后的树没有像预期的那样剪枝成两部分。说明后剪枝可能不如预剪枝有效。可以同时使用两种剪枝方式。
模型树:把叶子节点设定为分段线性函数。利用数生成算法对数据切分,且每份切分数据容易被线性模型表示。该算法的关键在于误差的计算。
对于给定的数据集,应该先用线性的模型对它拟合,然后计算真是的目标值与模型预测值之间的差值,再将这些差值的平方求和就得到了所需要的误差。
模型树的叶节点生成函数:
def linearSolve(dataSet): m, n = shape(dataSet) X = mat(ones((m, n))) #第一列仍为1 Y = mat(ones((m, 1))) X[:, 1:n] = dataSet[:, 0:n - 1] # print('X:',X) Y = dataSet[:, -1] # 将X,Y中的数据格式化 # print('Y:',Y) xTx = X.T * X if linalg.det(xTx) == 0.0: raise NameError("此矩阵不可逆。") # ws = linalg.pinv(xTx) * (X.T * Y) ws = xTx.I * (X.T * Y) return ws, X, Ydef modelLeaf(dataSet): # 当数据不再需要切分的时候它负责生成叶节点模型 ws, X, Y = linearSolve(dataSet) return wsdef modelErr(dataSet): ws, X, Y = linearSolve(dataSet) yHat = X * ws return sum(power(Y - yHat, 2))
数据集散点图如下:
测试:
myMat=mat(loadDataSet('exp2.txt')) plotPoint(myMat) myTree=createTree(myMat,modelLeaf,modelErr,(1,10)) print(myTree)
输出结果:
{'spInd': 0, 'spVal': 0.285477, 'right': matrix([[3.46877936], [1.18521743]]), 'left': matrix([[1.69855694e-03], [1.19647739e+01]])}
将数据集从x=0.285477分开,分别用两段线性模型来拟合。
树回归与标准回归的比较:相关系数
用树回归进行预测的代码:包括回归树和模型树两种树
def regTreeEval(model, inDat): #回归树效果评估 return float(model)def modelTreeEval(model, inDat): #模型树效果评估 n = shape(inDat)[1] X = mat(ones((1, n + 1))) X[:, 1:n + 1] = inDat return float(X * model)def treeForeCast(tree, inData, modelEval=regTreeEval): if not isTree(tree): return modelEval(tree, inData) # 如果输入单个数据或行向量,返回一个浮点值 else: if inData[tree["spInd"]] > tree["spVal"]: if isTree(tree["left"]): return treeForeCast(tree["left"], inData, modelEval) else: return modelEval(tree["left"], inData) else: if isTree(tree["right"]): return treeForeCast(tree["right"], inData, modelEval) else: return modelEval(tree["right"], inData)def createForeCast(tree, testData, modelEval=regTreeEval): #测试不同回归树的效果 m = len(testData) yHat = mat(zeros((m, 1))) for i in range(m): yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval) # 多次调用treeForeCast函数,将结果以列的形式放到yHat变量中 return yHat
因为代码中已经含有标准线性回归函数(linearSolve),所以不必重新写其生成代码。
测试:
if __name__=='__main__': trainMat = mat(loadDataSet("bikeSpeedVsIq_train.txt")) testMat = mat(loadDataSet("bikeSpeedVsIq_test.txt")) myTree = createTree(trainMat, ops=(1, 20)) yHat = createForeCast(myTree, testMat[:, 0]) print("回归树的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1]) myTree = createTree(trainMat, modelLeaf, modelErr, (1, 20)) yHat = createForeCast(myTree, testMat[:, 0], modelTreeEval) print("模型树的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1]) ws, X, Y = linearSolve(trainMat) print("线性回归系数:", ws) for i in range(shape(testMat)[0]): yHat[i] = testMat[i, 0] * ws[1, 0] + ws[0, 0] print("线性回归模型的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1])
输出:
回归树的相关系数: 0.964085231822215模型树的相关系数: 0.9760412191380629线性回归系数: [[37.58916794] [ 6.18978355]]线性回归模型的相关系数: 0.9434684235674766
相关系数越接近1越好,所以,模型树>回归树>标准线性回归。