【统计分析数学模型】判别分析(四):机器学习分类算法
一、机器学习分类算法
机器学习中的分类算法也常常用来解决判别分析问题。常见的分类算法包括决策树、K最邻近、支持向量机、神经网络、随机森林等。
1. 交叉验证方法
在用这些算法建立分类模型时,如果用全部数据建立模型并用回代法进行模型的内部验证,可能会出现过度拟合现象。因此,在建模时需要进行交叉验证以避免过度拟合问题。常用的交叉验证方法有以下3种:
- 保留交叉验证(hand-out cross validation)
将样本集随机分成训练集(training set)和验证集(test set),比例通常是7∶3或8∶2。使用模型在训练集上学习得到假设,然后使用验证集对假设进行验证,看模型预测的准确性,选择误差小的模型。 - k折交叉验证(k-fold cross validation)
把样本集分成k份,分别使用其中的k−1份作为训练集,剩下的1份作为交叉验证集,最后通过所有模型的平均误差来评估模型参数。 - 留一法验证(leave-one-out validation)
实质上是n折交叉验证,n是样本集的大小,就是只留下一个样本来验证模型的准确性。
2. 案例数据集
以MASS包中的数据集Pima.tr和Pima.te为例说明常见机器学习分类算法的实现方法。
它们是居住在美国某地区皮马印第安人后裔中部分女性的糖尿病调查数据。两个数据框中的变量都是一样的,其含义如下。
- npreg:怀孕次数
- glu:血糖浓度
- bp:舒张压(单位:mmHg)
- skin:三头肌皮褶厚度(单位:mm)
- bmi:体质指数
- ped:糖尿病家族史因素
- age:年龄
- type:是否患有糖尿病(Yes/No)
其中,结局变量为type,其余均为数值型的预测变量。
3. 数据标准化
首先,加载这两个数据集并分别将它们作为训练集和测试集:
library(MASS)
data(Pima.tr)
data(Pima.te)
由于预测变量的测量单位之间有较大差异,下面用函数scale()将它们标准化,都转换为均值为0、标准差为1的变量,并将训练集命名为data.train,测试集命名为data.test:
data.train<-Pima.tr
data.train[,-8]<-scale(data.train[,-8])
data.test<-Pima.te
data.test[,-8]<-scale(data.test[,-8])
二、决策树模型
1. 基本原理
决策树(decision tree)模型 是一种简单易用的非参数分类方法。它不需要对数据的分布有任何的先验假设,计算速度快,结果也容易解释。
分类回归树方法(CART) 是决策树模型中的一种经典算法,CART分为分类树(classification tree)和回归树(regression tree)两种。
分类树用于因变量为分类数据的情况,树的末端为因变量的类别;回归树用于因变量为连续型变量的情况,树的末端给出相应类别中的因变量描述或预测。
2. 计算步骤
- 首先对所有自变量和所有分隔点进行评估,最佳的选择是使分隔后组内的数据“纯度”更高,即组内目标变量的变异最小;
- 再对分类树模型进行修剪或称为剪枝:如果不加任何限制,过度复杂的分类树模型很容易产生“过度拟合”的问题。
通常使用CP参数(complexity parameter)控制树的复杂度。CP参数取值越小,模型越复杂,越偏向于过度拟合。
通常的做法是先建立一个枝节较多的分类树模型,再使用交叉验证的方法来估计不同“剪枝”条件下各个模型的误差,从而选择误差最小的分类树模型。
3. R语言实现
library(rpart)
set.seed(123)
pima.rpart<-rpart(type~.,data=data.train,control=rpart.control(cp=0))
pima.rpart
决策树模型的输出结果看起来像是以树状形式排列的一系列if-else语句。每行括号前面的数字代表节点,行缩进表示分支,*表示叶节点,loss表示误差数量。各节点后括号里的数值代表了各类别的比例:
> pima.rpart
n= 200
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 200 68 No (0.66000000 0.34000000)
2) glu< -0.01484184 109 15 No (0.86238532 0.13761468)
4) age< -0.3289163 74 4 No (0.94594595 0.05405405) *
5) age>=-0.3289163 35 11 No (0.68571429 0.31428571)
10) glu< -1.072718 9 0 No (1.00000000 0.00000000) *
11) glu>=-1.072718 26 11 No (0.57692308 0.42307692)
22) bp>=-0.2839819 19 6 No (0.68421053 0.31578947) *
23) bp< -0.2839819 7 2 Yes (0.28571429 0.71428571) *
3) glu>=-0.01484184 91 38 Yes (0.41758242 0.58241758)
6) ped< -0.4923593 35 12 No (0.65714286 0.34285714)
12) glu< 1.32724 27 6 No (0.77777778 0.22222222) *
13) glu>=1.32724 8 2 Yes (0.25000000 0.75000000) *
7) ped>=-0.4923593 56 15 Yes (0.26785714 0.73214286)
14) bmi< -0.597043 11 3 No (0.72727273 0.27272727) *
15) bmi>=-0.597043 45 7 Yes (0.15555556 0.84444444) *
用plot()函数绘制分类树:
plot(pima.rpart,uniform=TRUE,margin=0.1)
text(pima.rpart,use.n=TRUE