使用mllib完成mnist手写识别任务

使用mllib完成mnist手写识别任务

  1. 小提示,通过restart命令重启已经退出了的容器

    sudo docker restart <contain id>

    请添加图片描述

  2. 完成识别任务准备工作

    1. 从以下网站下载数据集:

      MNIST手写数字数据库,Yann LeCun,Corinna Cortes和Chris Burges

      数据集包含以下四个压缩包,下载后解压得到数据集文件:

      • t10k-images-idx3-ubyte.gz
      • t10k-labels-idx1-ubyte.gz
      • train-images-idx3-ubyte.gz
      • train-labels-idx1-ubyte.gz
    2. 通过以下python程序,将数据集文件转换为csv文件

      def convert(imgf, labelf, outf, n):
          f = open(imgf, "rb")
          o = open(outf, "w")
          l = open(labelf, "rb")
      
          f.read(16)
          l.read(8)
          images = []
      
          for i in range(n):
              image = [ord(l.read(1))]
              for j in range(28 * 28):
                  image.append(ord(f.read(1)))
              images.append(image)
      
          for image in images:
              o.write(",".join(str(pix) for pix in image) + "\n")
          f.close()
          o.close()
          l.close()
      
      
      # 数据集在 http://yann.lecun.com/exdb/mnist/ 下载
      convert("train-images.idx3-ubyte", "train-labels.idx1-ubyte",
              "mnist_train.csv", 60000)
      convert("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte",
              "mnist_test.csv", 10000)
      

      通过这个程序将在根目录下产生以下两个文件:

      • mnist_train.csv
      • mnist_test.csv
    3. 通过以下python程序转换csv文件为libsvm文件

      import csv
      
      
      def execute(data, savepath):
      
          csv_reader = csv.reader(open(data))
          f = open(savepath, 'wb')
          for line in csv_reader:
              label = line[0]
              features = line[1:]
              libsvm_line = label + ' '
      
              for index, feature in enumerate(features):
                  libsvm_line += str(index + 1) + ':' + feature + ' '
              f.write(bytes(libsvm_line.strip() + '\n', 'UTF-8'))
      
          f.close()
      
      
      execute('mnist_train.csv', 'mnist_train.libsvm')
      execute('mnist_test.csv', 'mnist_test.libsvm')
      

      该程序将生成以下两个.libsvm文件:

      • mnist_test.libsvm
      • mnist_train.libsvm
    4. 通过共享目录传递数据集到spark-master容器内。

    5. 进入spark-master

      sudo docker exec -it spark-master /bin/bash

      请添加图片描述

    6. 打开spark-shell

      spark-shell位于/spark/bin目录下

      使用./spark-shell命令进入spark-shell。

      请添加图片描述

  3. 完成识别任务

    1. 读取训练集

      val train = spark.read.format("libsvm").load("/data/mnist_train.libsvm")
      

      请添加图片描述

    2. 读取测试集

      val test = 		spark.read.format("libsvm").load("/data/mnist_test.libsvm")
      

      请添加图片描述

    3. 定义网络结构。如果计算机性能不好可以降低隐藏层的参数。

      val layers = Array[Int](784, 784, 784, 10)
      

      请添加图片描述

    4. 导入多层感知机与多分类评价器。

      import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
      import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
      

      请添加图片描述

    5. 使用多层感知机初始化训练器。

      val trainer = new MultilayerPerceptronClassifier().setLayers(layers).setBlockSize(128).setSeed(1234L).setMaxIter(100)
      

      请添加图片描述

    6. 训练模型

      var model = trainer.fit(train)
      

      请添加图片描述

      请添加图片描述

    7. 输入测试集进行识别

      val result = model.transform(test)
      

      请添加图片描述

    8. 获取测试结果中的预测结果与实际结果

      val predictionAndLabels = result.select("prediction", "label")
      

      请添加图片描述

    9. 初始化评价器

      val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
      

      请添加图片描述

    10. 计算识别精度

      println(s"Test set accuracy = ${evaluator.evaluate(predictionAndLabels)}")
      

      请添加图片描述

    11. 在result上创建临时视图

      result.toDF.createOrReplaceTempView("deep_learning")
      

      请添加图片描述

    12. 使用Spark SQL的方式计算识别精度

      spark.sql("select (select count(*) from deep_learning where label=prediction)/count(*) as accuracy from deep_learning").show()
      

      请添加图片描述

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值