TensorFlowSharp from Getting Started to Giving Up (Part 2): This Time with Handwritten Digit Recognition

Continuing from the previous post: the original sample code is written for a console app, so every output statement of the form Console.WriteLine(*); has been consistently changed to this.textBox1.Text += "\r\n" + string.Format(*);
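Since that pattern repeats in every sample below, a small helper method can keep the listings shorter. This is just a minimal sketch of that idea (the AppendLine name is mine, not from the original code; it simply wraps the += shown above):

        // Hypothetical helper: appends one formatted line to textBox1,
        // mirroring what Console.WriteLine did in the original console samples.
        private void AppendLine(string format, params object[] args)
        {
            this.textBox1.Text += "\r\n" + string.Format(format, args);
        }

        // e.g. AppendLine("Value: {0}", tensor.GetValue());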

This update adds four more lessons. The handwritten digit recognition one held me up for a while, mainly because the TF# MNIST code scattered around the web is plentiful and inconsistent across versions, so the latest version (which is also the original author's code) is included here for future readers.


        /// <summary>
        /// 06 Linear regression (not working)
        /// </summary>
        private void button6_Click(object sender, EventArgs e)
        {
            // Generate the sample data
            var xList = new List<double>();
            var yList = new List<double>();
            var ran = new Random();
            for (var i = 0; i < 10; i++)
            {
                var num = ran.NextDouble();
                var noise = ran.NextDouble();
                xList.Add(num);
                yList.Add(num * 3 + 4 + noise); // y = 3 * x + 4
            }
            var xData = xList.ToArray();
            var yData = yList.ToArray();
            var learning_rate = 0.01;

            // Create the graph
            var g = new TFGraph();

            // Create placeholders for the inputs
            var x = g.Placeholder(TFDataType.Double, new TFShape(xData.Length));
            var y = g.Placeholder(TFDataType.Double, new TFShape(yData.Length));

            // Weight and bias variables
            var W = g.VariableV2(TFShape.Scalar, TFDataType.Double, operName: "weight");
            var b = g.VariableV2(TFShape.Scalar, TFDataType.Double, operName: "bias");

            var initW = g.Assign(W, g.Const(ran.NextDouble()));
            var initb = g.Assign(b, g.Const(ran.NextDouble()));

            var output = g.Add(g.Mul(x, W), b);

            // Loss, its gradients, and the gradient-descent update ops
            var loss = g.ReduceSum(g.Abs(g.Sub(output, y)));
            var grad = g.AddGradients(new TFOutput[] { loss }, new TFOutput[] { W, b });
            var optimize = new[]
            {
                g.AssignSub(W, g.Mul(grad[0], g.Const(learning_rate))).Operation,
                g.AssignSub(b, g.Mul(grad[1], g.Const(learning_rate))).Operation
            };

            // Create the session
            var sess = new TFSession(g);

            // Initialize the variables
            sess.GetRunner().AddTarget(initW.Operation, initb.Operation).Run();

            // Run the training loop to fit the line
            for (var i = 0; i < 1000; i++)
            {
                var result = sess.GetRunner()
                    .AddInput(x, xData)
                    .AddInput(y, yData)
                    .AddTarget(optimize)
                    .Fetch(loss, W, b).Run();

                // Assignment (not append) so only the latest iteration's values are shown in the TextBox
                this.textBox1.Text = string.Format("loss: {0} W: {1} b: {2}", result[0].GetValue(), result[1].GetValue(), result[2].GetValue());
               
            }
        }
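For readers puzzled by the AddGradients/AssignSub pair above: the graph performs plain gradient descent, i.e. on every step W = W - learning_rate * dLoss/dW and likewise for b. Below is a minimal plain-C# sketch of that same update (my own illustration, not part of the original sample; it assumes the xData/yData arrays built above, and the names w, bias and lr are made up):

            // Illustration only: the update rule the graph applies, written without TensorFlow.
            // loss = sum_i |w * x_i + bias - y_i|, so its (sub)gradient uses the sign of each error.
            double w = 0.5, bias = 0.5, lr = 0.01;
            for (var step = 0; step < 1000; step++)
            {
                double gradW = 0, gradB = 0;
                for (var k = 0; k < xData.Length; k++)
                {
                    var sign = Math.Sign(w * xData[k] + bias - yData[k]);
                    gradW += sign * xData[k];
                    gradB += sign;
                }
                w -= lr * gradW;    // same effect as g.AssignSub(W, g.Mul(grad[0], g.Const(learning_rate)))
                bias -= lr * gradB; // same effect as g.AssignSub(b, g.Mul(grad[1], g.Const(learning_rate)))
            }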

        /// <summary>
        /// 07 Handwritten digit recognition
        /// </summary>
        private void button7_Click(object sender, EventArgs e)
        {
            // Load the MNIST handwritten digit data
            var mnist = Mnist.Load();

            // Number of training samples and test samples to use
            var trainCount = 5000;
            var testCount = 200;

            // Get the training images, training labels, test images and test labels
            float[,] trainingImages, trainingLabels, testImages, testLabels;
            mnist.GetTrainReader().NextBatch(trainCount, out trainingImages, out trainingLabels);
            mnist.GetTestReader().NextBatch(testCount, out testImages, out testLabels);
            //mnist.GetTrainReader().NextBatch(trainCount);
            //mnist.GetTestReader().NextBatch(testCount);


            // Create the graph
            var g = new TFGraph();

            // Placeholders: one for the whole batch of training images, one for a single test image
            var trainingInput = g.Placeholder(TFDataType.Float, new TFShape(-1, 784)); // any number of 28*28 (= 784-pixel) images
            var xte = g.Placeholder(TFDataType.Float, new TFShape(784)); // one test image

            // Graph ops that compute the L1 distance to every training image and the nearest-neighbour prediction
            var distance = g.ReduceSum(g.Abs(g.Add(trainingInput, g.Neg(xte))), axis: g.Const(1));
            var pred = g.ArgMin(distance, g.Const(0));

            // Create the session
            var sess = new TFSession(g);

            // Accuracy
            var accuracy = 0.0f;

            // Classify each test image and print its prediction
            for (int i = 0; i < testCount; i++)
            {
                var runner = sess.GetRunner();

                // Run the graph and fetch the prediction and the distances
                var result = runner.
                    Fetch(pred).
                    Fetch(distance).
                    AddInput(trainingInput, trainingImages).
                    AddInput(xte, Extract(testImages, i)).Run();
                var r = result[0].GetValue();
                var tr = result[1].GetValue();

                var nn_index = (int)(long)result[0].GetValue();

                this.textBox1.Text += "\r\n" + $"Test {i}: prediction: {ArgMax(trainingLabels, nn_index)} true value: {ArgMax(testLabels, i)} (nn_index = {nn_index})";
                if (ArgMax(trainingLabels, nn_index) == ArgMax(testLabels, i))
                    accuracy += 1f / testCount; // note: testImages.Length is the total element count of the 2D array, not the number of test images
            }

            // Overall accuracy
            this.textBox1.Text += "\r\n" + string.Format("Accuracy: {0}", accuracy);
        }

        /// <summary>
        /// Returns the index of the largest value in row idx of the matrix array
        /// </summary>
        /// <param name="array"></param>
        /// <param name="idx"></param>
        /// <returns></returns>
        static int ArgMax(float[,] array, int idx)
        {
            float max = -1;
            int maxIdx = -1;
            var len = array.GetLength(1);
            for (int i = 0; i < len; i++)
                if (array[idx, i] > max)
                {
                    maxIdx = i;
                    max = array[idx, i];
                }
            return maxIdx;
        }

        /// <summary>
        /// Returns row index of the matrix array (i.e. extracts one image from the image batch)
        /// </summary>
        /// <param name="array"></param>
        /// <param name="index"></param>
        /// <returns></returns>
        static public float[] Extract(float[,] array, int index)
        {
            var n = array.GetLength(1);
            var ret = new float[n];

            for (int i = 0; i < n; i++)
                ret[i] = array[index, i];
            return ret;
        }
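To be clear, the graph in section 07 is not really "training" anything: it is a one-nearest-neighbour classifier. For each test image it computes the L1 distance to every training image, takes the closest one, and uses that training image's label as the prediction. Here is a minimal plain-C# sketch of the same computation, for illustration only (not in the original sample; it reuses the ArgMax and Extract helpers above, and the method name PredictNearestNeighbour is made up):

        /// <summary>
        /// Illustration only: what the distance/ArgMin graph in section 07 computes, written directly in C#.
        /// </summary>
        static int PredictNearestNeighbour(float[,] trainingImages, float[,] trainingLabels, float[,] testImages, int i)
        {
            var test = Extract(testImages, i);
            var best = 0;
            var bestDist = float.MaxValue;
            for (int t = 0; t < trainingImages.GetLength(0); t++)
            {
                float dist = 0;
                for (int p = 0; p < test.Length; p++)
                    dist += Math.Abs(trainingImages[t, p] - test[p]); // L1 distance, like ReduceSum(Abs(...))
                if (dist < bestDist)
                {
                    bestDist = dist; // ArgMin over the training set
                    best = t;
                }
            }
            return ArgMax(trainingLabels, best); // label of the nearest training image
        }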

        /// <summary>
        /// 08 Working with tensors
        /// </summary>
        private void button8_Click(object sender, EventArgs e)
        {
            // Integer (scalar) tensor
            var tensor = new TFTensor(1);
            this.textBox1.Text +="\r\n"+ string.Format("Value:" + tensor.GetValue());

            // Matrix tensor
            tensor = new TFTensor(new int[,]
            {
                { 1, 2, 3 },
                { 4, 5, 6 },
            });
            this.textBox1.Text += "\r\n" + string.Format("TensorType: {0}", tensor.TensorType.ToString());
            this.textBox1.Text += "\r\n" + string.Format("NumDims: " + tensor.NumDims);
            this.textBox1.Text += "\r\n" + string.Format("Shape", string.Join(",", tensor.Shape));
            for (var i = 0; i < tensor.NumDims; i++)
            {
                var dim = tensor.GetTensorDimension(i);
                this.textBox1.Text += "\r\n" + string.Format("DimIndex: {0}, Dim: {1}", i, dim);
            }

            // Create the graph
            var g = new TFGraph();

            // Create a string tensor
            tensor = new TFTensor("Hello, world!".Select(o => (sbyte)o).ToArray());
            var hello = g.Const(tensor);

            // Create the session
            var sess = new TFSession(g);

            // Run the computation
            var result = sess.GetRunner().Run(hello).GetValue();

            // Print the result
            this.textBox1.Text += "\r\n" + string.Format(string.Join("", ((sbyte[])result).Select(o => (char)o)));
        }

Put the following three class files into your own project. Many versions of MNIST.cs circulate online and some of their functions differ; the latest version is below:

MNIST.cs code

//
// Code to download and load the MNIST data.
//

using System;
using System.IO;
using System.IO.Compression;
using Mono;
using TensorFlow;
using System.Linq;

namespace Learn.Mnist
{
    // Stores the per-image MNIST information we loaded from disk 
    //
    // We store the data in two formats, byte array (as it came in from disk), and float array
    // where each 0..255 value has been mapped to 0.0f..1.0f
    public struct MnistImage
    {
        public int Cols, Rows;
        public byte[] Data;
        public float[] DataFloat;

        public MnistImage(int cols, int rows, byte[] data)
        {
            Cols = cols;
            Rows = rows;
            Data = data;
            DataFloat = new float[data.Length];
            for (int i = 0; i < data.Length; i++)
            {
                DataFloat[i] = Data[i] / 255f;
            }
        }
    }

    // Helper class used to load and work with the Mnist data set
    public class Mnist
    {
        // 
        // The loaded results
        //
        public MnistImage[] TrainImages, TestImages, ValidationImages;
        public byte[] TrainLabels, TestLabels, ValidationLabels;
        public byte[,] OneHotTrainLabels, OneHotTestLabels, OneHotValidationLabels;

        public BatchReader GetTrainReader() => new BatchReader(TrainImages, TrainLabels, OneHotTrainLabels);
        public BatchReader GetTestReader() => new BatchReader(TestImages, TestLabels, OneHotTestLabels);
        public BatchReader GetValidationReader() => new BatchReader(ValidationImages, ValidationLabels, OneHotValidationLabels);

        public class BatchReader
        {
            int start = 0;
            MnistImage[] source;
            byte[] labels;
            byte[,] oneHotLabels;

            internal BatchReader(MnistImage[] source, byte[] labels, byte[,] oneHotLabels)
            {
                this.source = source;
                this.labels = labels;
                this.oneHotLabels = oneHotLabels;
            }

            public void NextBatch(int batchSize, out float[,] imageData, out float[,] labelData)
            {
                imageData = new float[batchSize, 784];
                labelData = new float[batchSize, 10];

                int p = 0;
                for (int item = 0; item < batchSize; item++)
                {
                    Buffer.BlockCopy(source[start + item].DataFloat, 0, imageData, p, 784 * sizeof(float));
                    p += 784 * sizeof(float);
                    for (var j = 0; j < 10; j++)
                        labelData[item, j] = oneHotLabels[item + start, j];
                }

                start += batchSize;
            }
        }

        int Read32(Stream s)
        {
            var x = new byte[4];
            s.Read(x, 0, 4);
            return DataConverter.BigEndian.GetInt32(x, 0);
        }

        MnistImage[] ExtractImages(Stream input, string file)
        {
            using (var gz = new GZipStream(input, CompressionMode.Decompress))
            {
                if (Read32(gz) != 2051)
                    throw new Exception("Invalid magic number found on the MNIST " + file);
                var count = Read32(gz);
                var rows = Read32(gz);
                var cols = Read32(gz);

                var result = new MnistImage[count];
                for (int i = 0; i < count; i++)
                {
                    var size = rows * cols;
                    var data = new byte[size];
                    gz.Read(data, 0, size);

                    result[i] = new MnistImage(cols, rows, data);
                }
                return result;
            }
        }


        byte[] ExtractLabels(Stream input, string file)
        {
            using (var gz = new GZipStream(input, CompressionMode.Decompress))
            {
                if (Read32(gz) != 2049)
                    throw new Exception("Invalid magic number found on the MNIST " + file);
                var count = Read32(gz);
                var labels = new byte[count];
                gz.Read(labels, 0, count);

                return labels;
            }
        }

        T[] Pick<T>(T[] source, int first, int last)
        {
            if (last == 0)
                last = source.Length;
            var count = last - first;
            var result = new T[count];
            Array.Copy(source, first, result, 0, count);
            return result;
        }

        // Turn the labels array that contains values 0..numClasses-1 into
        // a One-hot encoded array
        byte[,] OneHot(byte[] labels, int numClasses)
        {
            var oneHot = new byte[labels.Length, numClasses];
            for (int i = 0; i < labels.Length; i++)
            {
                oneHot[i, labels[i]] = 1;
            }
            return oneHot;
        }

        /// <summary>
        /// Reads the data sets.
        /// </summary>
        /// <param name="trainDir">Directory where the training data is downlaoded to.</param>
        /// <param name="numClasses">Number classes to use for one-hot encoding, or zero if this is not desired</param>
        /// <param name="validationSize">Validation size.</param>
        public void ReadDataSets(string trainDir, int numClasses = 10, int validationSize = 5000)
        {
            const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
            const string TrainImagesName = "train-images-idx3-ubyte.gz";
            const string TrainLabelsName = "train-labels-idx1-ubyte.gz";
            const string TestImagesName = "t10k-images-idx3-ubyte.gz";
            const string TestLabelsName = "t10k-labels-idx1-ubyte.gz";

            TrainImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TrainImagesName), TrainImagesName);
            TestImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TestImagesName), TestImagesName);
            TrainLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TrainLabelsName), TrainLabelsName);
            TestLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TestLabelsName), TestLabelsName);

            ValidationImages = Pick(TrainImages, 0, validationSize);
            ValidationLabels = Pick(TrainLabels, 0, validationSize);
            TrainImages = Pick(TrainImages, validationSize, 0);
            TrainLabels = Pick(TrainLabels, validationSize, 0);

            if (numClasses != -1)
            {
                OneHotTrainLabels = OneHot(TrainLabels, numClasses);
                OneHotValidationLabels = OneHot(ValidationLabels, numClasses);
                OneHotTestLabels = OneHot(TestLabels, numClasses);
            }
        }

        public static Mnist Load()
        {
            var x = new Mnist();
            x.ReadDataSets(Environment.CurrentDirectory + "\\tmp");
            return x;
        }
    }
}
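A minimal usage sketch for the class above (it mirrors what button7_Click already does, and assumes the four MNIST files can be downloaded to .\tmp on first run):

            // Illustration only: load MNIST and read one batch of images and one-hot labels.
            var mnist = Mnist.Load(); // downloads the data to <current directory>\tmp if missing
            float[,] images, labels;
            mnist.GetTrainReader().NextBatch(100, out images, out labels);
            // images is [100, 784] with pixel values scaled to 0..1; labels is [100, 10] one-hot.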

DataConverter.cs file

//
// Authors:
//   Miguel de Icaza (miguel@novell.com)
//
// See the following url for documentation:
//     http://www.mono-project.com/Mono_DataConvert
//
// Compilation Options:
//     MONO_DATACONVERTER_PUBLIC:
//         Makes the class public instead of the default internal.
//
//     MONO_DATACONVERTER_STATIC_METHODS:     
//         Exposes the public static methods.
//
// TODO:
//   Support for "DoubleWordsAreSwapped" for ARM devices
//
// Copyright (C) 2006 Novell, Inc (http://www.novell.com)
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software withou