一、交叉验证
在建立分类模型时,交叉验证(Cross Validation)简称为CV,CV是用来验证分类器的性能。它的主体思想是将原始数据进行分组,一部分作为训练集,一部分作为验证集。利用训练集训练出模型,利用验证集来测试模型,以评估分类模型的性能。
二、交叉验证的作用
- 验证分类器的性能
- 用于模型的选择
三、交叉验证常用的几种方法
3.1 k折交叉验证 K-fold Cross Validation(记为K-CV)
1、将数据集平均分割成K个等份(参数cv值,一般选择5折10折,即测试集为20%)
2、使用1份数据作为测试数据,其余作为训练数据
3、计算测试准确率
4、使用不同的测试集,重复2、3步
5、对测试准确率做平均,作为对未知数据预测准确率的估计
优点:
因为每一个样本数据既可以作为测试集又可以作为训练集,可有效避免欠学习和过学习状态的发生,得到的结果比较有说服力。
3.2 留一法交叉验证 Leave-One-Out Cross Validation(记为LOO-CV)
假设样本数据集中有N个样本数据。将每个样本单独作为测试集,其余N-1个样本作为训练集,这样得到了N个分类器或模型,用这N个分类器或模型的分类准确率的平均数作为此分类器的性能指标。
优点:
a. 每一个分类器或模型几乎所有的样本都用来作为训练模型,因此最接近样本,实验评估可靠;
b. 实验过程没有随机因素影响实验结果,所以实验结果可复制,因此实验结果稳定。
缺点:
计算成本高,因为需要建立的模型数量与样本数据数量相同,当N很大时,计算相当耗时。
3.3 留p交叉验证
留p验证指训练集上随机选择p个样本作为测试集,其余作为子训练集。时间复杂度为CpN,是阶乘的复杂度,不可取。
3.4 重复随机子抽样验证 Hold-Out Method
将数据集随机划分为训练集和测试集。对每一个划分,用训练集训练分类器或模型,用测试集评估预测的精确度。进行多次划分,用均值来表示效能。
优点:
与K值无关。严格意义来说Hold-Out Method不属于交叉验证方法,这种方法与k无关。
缺点:
验证集结果准确率的高低和原始分组有很大关系,可能导致一些数据从未做过训练或测试数据;而一些数据不止一次选为训练或测试数据的情况发生,因此结果不具有说服力。
四、交叉验证函数
cross_val_score详情可见官网
train_test_split
#导入
from sklearn.cross_validation import cross_val_score
from sklearn.cross_validation import train_test_split
五、代码
例子:垃圾邮件分类
input:
from numpy import *
from sklearn import metrics
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import GaussianNB as NB
from sklearn.neighbors import KNeighborsClassifier as KNN
from sklearn.linear_model import LogisticRegression as LR
#将词条合并为一个列表
def createVocabList(dataSet):
vocabSet = set([]) #创建一个空集
for document in dataSet:
vocabSet = vocabSet | set(document) #创建两个集合的并集
return list(vocabSet)
#将词汇转化为向量
def bagOfWords2VecMN(vocabList, inputSet):
returnVec = [0]*len(vocabList) #初始化 词汇等长的0向量
for word in inputSet:
if word in vocabList:
returnVec[vocabList.index(word)] += 1
return returnVec
#预处理 统一小写,去除长度小于2个的词汇
def textParse(bigString):
import re
listOfTokens = re.split(r'\W*', bigString)
return [tok.lower() for tok in listOfTokens if len(tok) > 2]
#统计词频前10
def calcMostFreq(vocabList,fullText):
import operator
freqDict = {}
for token in vocabList:
freqDict[token]=fullText.count(token)
sortedFreq = sorted(freqDict.items(), key=operator.itemgetter(1), reverse=True)
return sortedFreq[:10]
#读取数据
def spamTest():
docList=[]; classList = []; fullText =[]
for i in range(1,26):
wordList = textParse(open('email/spam/%d.txt' % i).read())
docList.append(wordList)
fullText.extend(wordList)
classList.append(1)
wordList = textParse(open('email/ham/%d.txt' % i).read())
docList.append(wordList)
fullText.extend(wordList)
classList.append(0)
vocabList = createVocabList(docList) #创建词列表
top10Words = calcMostFreq(vocabList,fullText) #删除词频前10
for pairW in top10Words:
if pairW[0] in vocabList: vocabList.remove(pairW[0])
trainingSet = list(range(50)) #0-49,,50个数字,50封邮件
train_data = [] #存储 所有训练词汇的向量
train_target = [] #存储 类别标签
for docIndex in trainingSet: #得到训练数据的向量
train_data.append(bagOfWords2VecMN(vocabList, docList[docIndex]))
train_target.append(classList[docIndex])
return train_data,train_target
5.1 cross_val_score
input:
from sklearn.cross_validation import cross_val_score
if __name__ == '__main__':
data = []
target = []
data, target = spamTest()
clf1 = KNN(n_neighbors=8)
clf2 = LR()
clf3 = NB()
#交叉验证 cv:数据分成的份数,其中一份作为cv集,其余n-1作为训练集(默认为3)
for clf,lable in zip([clf1, clf2, clf3],['KNN','LR','NB']):
scores = cross_val_score(clf,data,target,cv=5,scoring='accuracy')
#print(scores)
print("Accuracy:%0.2f (+/-%0.2f)[%s]"%(scores.mean(),scores.std(),lable)) #计算均值及标准差
output:
G:\Anacanda3\lib\re.py:212: FutureWarning: split() requires a non-empty pattern match.
return _compile(pattern, flags).split(string, maxsplit)
Accuracy:0.64 (+/-0.05)[KNN]
Accuracy:0.94 (+/-0.05)[LR]
Accuracy:0.92 (+/-0.07)[NB]
5.2 train_test_split
input:
from sklearn.cross_validation import train_test_split
if __name__ == '__main__':
data = []
target = []
data, target = spamTest()
clf1 = KNN(n_neighbors=8)
clf2 = LR()
clf3 = NB()
'''
#交叉验证 cv:数据分成的份数,其中一份作为cv集,其余n-1作为训练集(默认为3)
for clf,lable in zip([clf1, clf2, clf3],['KNN','LR','NB']):
scores = cross_val_score(clf,data,target,cv=5,scoring='accuracy')
#print(scores)
print("Accuracy:%0.2f (+/-%0.2f)[%s]"%(scores.mean(),scores.std(),lable)) #计算均值及标准差
'''
#交叉验证
x_train,x_test,y_train,y_test = train_test_split(data,target,test_size=0.2) #交叉验证 20%选取测试集
clf = clf2.fit(x_train, y_train)
predicted = clf.predict(x_test)
expected = y_test
print(metrics.classification_report(expected, predicted))
print(metrics.confusion_matrix(expected, predicted))
print('Score:',accuracy_score(expected,predicted))
output:
G:\Anacanda3\lib\re.py:212: FutureWarning: split() requires a non-empty pattern match.
return _compile(pattern, flags).split(string, maxsplit)
precision recall f1-score support
0 1.00 1.00 1.00 5
1 1.00 1.00 1.00 5
avg / total 1.00 1.00 1.00 10
[[5 0]
[0 5]]
Score: 1.0