使用Decision Tree对MNIST数据集进行实验

使用的Decision Tree中,对MNIST中的灰度值进行了0/1处理,方便来进行分类和计算熵。

使用较少的测试数据测试了在对灰度值进行多分类的情况下,分类结果的正确率如何。实验结果如下。

#Test change pixel data into more categories than 0/1:
#int(pixel)/50: 37%
#int(pixel)/64: 45.9%
#int(pixel)/96: 52.3%
#int(pixel)/128: 62.48%
#int(pixel)/152: 59.1%
#int(pixel)/176: 57.6%
#int(pixel)/192: 54.0%

可见,在对灰度数据进行二分类,也就是0/1处理时,效果是最好的。

使用0/1处理,最终结果如下:

#Result:
#Train with 10k, test with 60k: 77.79%
#Train with 60k, test with 10k: 87.3%
#Time cost: 3 hours.

最终结果是87.3%的正确率。与SVM和KNN的超过95%相比,差距不小。而且消耗时间更长。

需要注意的是,此次Decision Tree算法中,并未对决策树进行剪枝。因此,还有可以提升的空间。

python代码见最下面。其中:

calcShannonEntropy(dataSet):是对矩阵的熵进行计算,根据各个数据点的分类情况,使用香农定理计算;

splitDataSet(dataSet, axis, value): 是获取第axis维度上的值为value的所有行所组成的矩阵。对于第axis维度上的数据,分别计算他们的splitDataSet的矩阵的熵,并与该维度上数据的出现概率相乘求和,可以得到使用第axis维度构建决策树后,整体的熵。

chooseBestFeatureToSplit(dataSet): 根据splitDataSet函数,对比得到整体的熵与原矩阵的熵相比,熵的增量最大的维度。根据此维度feature来构建决策树。

createDecisionTree(dataSet, features): 递归构建决策树。若在叶子节点处没法分类,则采用majorityCnt(classList)方法统计出现最多次的class作为分类。

代码如下:

  1. #Decision tree for MNIST dataset by arthur503.  
  2. #Data format: 'class    label1:pixel    label2:pixel ...'  
  3. #Warning: without fix overfitting!  
  4. #  
  5. #Test change pixel data into more categories than 0/1:  
  6. #int(pixel)/50: 37%  
  7. #int(pixel)/64: 45.9%  
  8. #int(pixel)/96: 52.3%  
  9. #int(pixel)/128: 62.48%  
  10. #int(pixel)/152: 59.1%  
  11. #int(pixel)/176: 57.6%  
  12. #int(pixel)/192: 54.0%  
  13. #  
  14. #Result:  
  15. #Train with 10k, test with 60k: 77.79%  
  16. #Train with 60k, test with 10k: 87.3%  
  17. #Time cost: 3 hours.  
  18.   
  19. from numpy import *  
  20. import operator  
  21.   
  22. def calcShannonEntropy(dataSet):  
  23.     numEntries = len(dataSet)  
  24.     labelCounts = {}  
  25.     for featureVec in dataSet:  
  26.         currentLabel = featureVec[0]  
  27.         if currentLabel not in labelCounts.keys():  
  28.             labelCounts[currentLabel] = 1  
  29.         else:  
  30.             labelCounts[currentLabel] += 1  
  31.     shannonEntropy = 0.0  
  32.     for key in labelCounts:  
  33.         prob = float(labelCounts[key])/numEntries  
  34.         shannonEntropy -= prob  * log2(prob)  
  35.     return shannonEntropy  
  36.   
  37. #get all rows whose axis item equals value.  
  38. def splitDataSet(dataSet, axis, value):  
  39.     subDataSet = []  
  40.     for featureVec in dataSet:  
  41.         if featureVec[axis] == value:  
  42.             reducedFeatureVec = featureVec[:axis]  
  43.             reducedFeatureVec.extend(featureVec[axis+1:])   #if axis == -1, this will cause error!  
  44.             subDataSet.append(reducedFeatureVec)  
  45.     return subDataSet  
  46.   
  47. def chooseBestFeatureToSplit(dataSet):  
  48.     #Notice: Actucally, index 0 of numFeatures is not feature(it is class label).  
  49.     numFeatures = len(dataSet[0])     
  50.     baseEntropy = calcShannonEntropy(dataSet)  
  51.     bestInfoGain = 0.0  
  52.     bestFeature = numFeatures - 1   #DO NOT use -1! or splitDataSet(dataSet, -1, value) will cause error!  
  53.     #feature index start with 1(not 0)!  
  54.     for i in range(numFeatures)[1:]:  
  55.         featureList = [example[i] for example in dataSet]  
  56.         featureSet = set(featureList)  
  57.         newEntropy = 0.0  
  58.         for value in featureSet:  
  59.             subDataSet = splitDataSet(dataSet, i, value)  
  60.             prob = len(subDataSet)/float(len(dataSet))  
  61.             newEntropy += prob * calcShannonEntropy(subDataSet)  
  62.         infoGain = baseEntropy - newEntropy  
  63.         if infoGain > bestInfoGain:  
  64.             bestInfoGain = infoGain  
  65.             bestFeature = i  
  66.     return bestFeature  
  67.   
  68. #classify on leaf of decision tree.  
  69. def majorityCnt(classList):  
  70.     classCount = {}  
  71.     for vote in classList:  
  72.         if vote not in classCount:  
  73.             classCount[vote] = 0  
  74.         classCount[vote] += 1  
  75.     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  
  76.     return sortedClassCount[0][0]  
  77.   
  78. #Create Decision Tree.  
  79. def createDecisionTree(dataSet, features):  
  80.     print 'create decision tree... length of features is:'+str(len(features))  
  81.     classList = [example[0] for example in dataSet]  
  82.     if classList.count(classList[0]) == len(classList):  
  83.         return classList[0]  
  84.     if len(dataSet[0]) == 1:  
  85.         return majorityCnt(classList)  
  86.     bestFeatureIndex = chooseBestFeatureToSplit(dataSet)   
  87.     bestFeatureLabel = features[bestFeatureIndex]  
  88.     myTree = {bestFeatureLabel:{}}  
  89.     del(features[bestFeatureIndex])  
  90.     featureValues = [example[bestFeatureIndex] for example in dataSet]  
  91.     featureSet = set(featureValues)  
  92.     for value in featureSet:  
  93.         subFeatures = features[:]     
  94.         myTree[bestFeatureLabel][value] = createDecisionTree(splitDataSet(dataSet, bestFeatureIndex, value), subFeatures)  
  95.     return myTree  
  96.   
  97. def line2Mat(line):  
  98.     mat = line.strip().split(' ')  
  99.     for i in range(len(mat)-1):   
  100.         pixel = mat[i+1].split(':')[1]  
  101.         #change MNIST pixel data into 0/1 format.  
  102.         mat[i+1] = int(pixel)/128  
  103.     return mat  
  104.   
  105. #return matrix as a list(instead of a matrix).  
  106. #features is the 28*28 pixels in MNIST dataset.  
  107. def file2Mat(fileName):  
  108.     f = open(fileName)  
  109.     lines = f.readlines()  
  110.     matrix = []  
  111.     for line in lines:  
  112.         mat = line2Mat(line)  
  113.         matrix.append(mat)  
  114.     f.close()  
  115.     print 'Read file '+str(fileName) + ' to array done! Matrix shape:'+str(shape(matrix))  
  116.     return matrix  
  117.   
  118. #Classify test file.  
  119. def classify(inputTree, featureLabels, testVec):  
  120.     firstStr = inputTree.keys()[0]  
  121.     secondDict = inputTree[firstStr]  
  122.     featureIndex = featureLabels.index(firstStr)  
  123.     predictClass = '-1'  
  124.     for key in secondDict.keys():  
  125.         if testVec[featureIndex] == key:  
  126.             if type(secondDict[key]) == type({}):     
  127.                 predictClass = classify(secondDict[key], featureLabels, testVec)  
  128.             else:  
  129.                 predictClass = secondDict[key]  
  130.     return predictClass  
  131.   
  132. def classifyTestFile(inputTree, featureLabels, testDataSet):  
  133.     rightCnt = 0  
  134.     for i in range(len(testDataSet)):  
  135.         classLabel = testDataSet[i][0]  
  136.         predictClassLabel = classify(inputTree, featureLabels, testDataSet[i])  
  137.         if classLabel == predictClassLabel:  
  138.             rightCnt += 1   
  139.         if i % 200 == 0:  
  140.             print 'num '+str(i)+'. ratio: ' + str(float(rightCnt)/(i+1))  
  141.     return float(rightCnt)/len(testDataSet)  
  142.   
  143. def getFeatureLabels(length):  
  144.     strs = []  
  145.     for i in range(length):  
  146.         strs.append('#'+str(i))  
  147.     return strs  
  148.   
  149. #Normal file  
  150. trainFile = 'train_60k.txt'   
  151. testFile = 'test_10k.txt'  
  152. #Scaled file  
  153. #trainFile = 'train_60k_scale.txt'  
  154. #testFile = 'test_10k_scale.txt'  
  155. #Test file  
  156. #trainFile = 'test_only_1.txt'    
  157. #testFile = 'test_only_2.txt'  
  158.   
  159. #train decision tree.  
  160. dataSet = file2Mat(trainFile)  
  161. #Actually, the 0 item is class, not feature labels.  
  162. featureLabels = getFeatureLabels(len(dataSet[0]))     
  163. print 'begin to create decision tree...'  
  164. myTree = createDecisionTree(dataSet, featureLabels)  
  165. print 'create decision tree done.'  
  166.   
  167. #predict with decision tree.      
  168. testDataSet = file2Mat(testFile)  
  169. featureLabels = getFeatureLabels(len(testDataSet[0]))     
  170. rightRatio = classifyTestFile(myTree, featureLabels, testDataSet)  
  171. print 'total right ratio: ' + str(rightRatio)  

 

 

转载于:https://www.cnblogs.com/wt869054461/p/5030853.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值