接上文。原示例代码是控制台程序:所有 Console.WriteLine(*); 形式的输出语句,一律改为 this.textBox1.Text += "\r\n" + string.Format(*);
这次又更新了四课内容。其中手写数字识别卡了一段时间,主要原因是网上 TensorFlowSharp(TF#)的 MNIST 手写数字识别代码版本繁多、较为混乱,因此这里补上最新版本(即原作者的代码),供后来者学习参考。
/// <summary>
/// 06 Linear regression: fit y = W * x + b to noisy samples of y = 3x + 4 with plain gradient descent
/// on the sum of absolute errors, logging loss/W/b after every step.
/// NOTE(review): the original author marked this lesson as not working ("没实现"); confirm that the
/// TensorFlow native build in use registers gradients for every op in the loss below.
/// </summary>
private void button6_Click(object sender, EventArgs e)
{
    // Training data: 10 samples, x in [0, 1), y = 3x + 4 plus uniform noise in [0, 1).
    var ran = new Random();
    var xData = new double[10];
    var yData = new double[10];
    for (var i = 0; i < xData.Length; i++)
    {
        var num = ran.NextDouble();
        var noise = ran.NextDouble();
        xData[i] = num;
        yData[i] = num * 3 + 4 + noise;
    }
    var learningRate = 0.01;

    // Collect all output and push it to the TextBox once: appending 1000 times to TextBox.Text
    // re-copies the whole text on every iteration.
    var log = new System.Text.StringBuilder();

    using (var g = new TFGraph())
    {
        // Inputs and targets, one element per sample.
        var x = g.Placeholder(TFDataType.Double, new TFShape(xData.Length));
        var y = g.Placeholder(TFDataType.Double, new TFShape(yData.Length));

        // Scalar weight and bias, initialised with random values.
        var W = g.VariableV2(TFShape.Scalar, TFDataType.Double, operName: "weight");
        var b = g.VariableV2(TFShape.Scalar, TFDataType.Double, operName: "bias");
        var initW = g.Assign(W, g.Const(ran.NextDouble()));
        var initb = g.Assign(b, g.Const(ran.NextDouble()));

        var output = g.Add(g.Mul(x, W), b);

        // Loss: sum of absolute errors.
        var loss = g.ReduceSum(g.Abs(g.Sub(output, y)));
        var grad = g.AddGradients(new TFOutput[] { loss }, new TFOutput[] { W, b });

        // One gradient-descent step: W -= lr * dW, b -= lr * db.
        var optimize = new[]
        {
            g.AssignSub(W, g.Mul(grad[0], g.Const(learningRate))).Operation,
            g.AssignSub(b, g.Mul(grad[1], g.Const(learningRate))).Operation
        };

        using (var sess = new TFSession(g))
        {
            // Run the initialisers before training.
            sess.GetRunner().AddTarget(initW.Operation, initb.Operation).Run();

            for (var i = 0; i < 1000; i++)
            {
                var result = sess.GetRunner()
                    .AddInput(x, xData)
                    .AddInput(y, yData)
                    .AddTarget(optimize)
                    .Fetch(loss, W, b).Run();
                // Append (rather than overwrite) so every step is visible, per the "\r\n" + Format convention.
                log.Append("\r\n").Append(string.Format("loss: {0} W:{1} b:{2}", result[0].GetValue(), result[1].GetValue(), result[2].GetValue()));
            }
        }
    }

    this.textBox1.Text += log.ToString();
}
/// <summary>
/// 07 Handwritten digit recognition on MNIST with a 1-nearest-neighbour classifier. Each test image
/// gets the label of the training image at the smallest L1 (sum of absolute differences) distance.
/// Logs every prediction and the final accuracy.
/// </summary>
private void button7_Click(object sender, EventArgs e)
{
    // Load the MNIST data set (see Mnist.Load).
    var mnist = Mnist.Load();

    // Number of reference (training) images and number of test images to classify.
    var trainCount = 5000;
    var testCount = 200;

    // Images are [count, 784] float arrays; labels are [count, 10] one-hot float arrays.
    float[,] trainingImages, trainingLabels, testImages, testLabels;
    mnist.GetTrainReader().NextBatch(trainCount, out trainingImages, out trainingLabels);
    mnist.GetTestReader().NextBatch(testCount, out testImages, out testLabels);

    // Build all output in memory and push it to the TextBox once.
    var log = new System.Text.StringBuilder();
    var correct = 0;

    using (var g = new TFGraph())
    {
        // A variable number of flattened reference images, 28*28 = 784 pixels each.
        var trainingInput = g.Placeholder(TFDataType.Float, new TFShape(-1, 784));
        // The single flattened test image being classified.
        var xte = g.Placeholder(TFDataType.Float, new TFShape(784));

        // L1 distance from the test image to every reference image, then the index of the closest one.
        var distance = g.ReduceSum(g.Abs(g.Add(trainingInput, g.Neg(xte))), axis: g.Const(1));
        var pred = g.ArgMin(distance, g.Const(0));

        using (var sess = new TFSession(g))
        {
            for (int i = 0; i < testCount; i++)
            {
                var result = sess.GetRunner()
                    .Fetch(pred)
                    .AddInput(trainingInput, trainingImages)
                    .AddInput(xte, Extract(testImages, i))
                    .Run();

                // The ArgMin result comes back boxed as a long (Int64 index).
                var nearestIndex = (int)(long)result[0].GetValue();
                var predicted = ArgMax(trainingLabels, nearestIndex);
                var actual = ArgMax(testLabels, i);

                log.Append("\r\n").Append($"训练次数 {i}: 预测: {predicted} 真实值: {actual} (nn_index= {nearestIndex})");
                if (predicted == actual)
                {
                    correct++;
                }
            }
        }
    }

    // Fraction of test images classified correctly. Divide by the number of test images,
    // not testImages.Length, which counts every pixel (testCount * 784).
    var accuracy = (float)correct / testCount;
    log.Append("\r\n").Append("精度:" + accuracy);

    this.textBox1.Text += log.ToString();
}
/// <summary>
/// Returns the column index of the largest value in row <paramref name="idx"/> of
/// <paramref name="array"/>, e.g. the digit encoded by a one-hot label row.
/// On ties the first such column wins; NaN entries are never selected.
/// </summary>
/// <param name="array">Matrix whose row is scanned.</param>
/// <param name="idx">Index of the row to scan.</param>
/// <returns>The index of the maximum column, or -1 if the row has no non-NaN values.</returns>
static int ArgMax(float[,] array, int idx)
{
    int maxIdx = -1;
    float max = float.NegativeInfinity;
    var len = array.GetLength(1);
    for (int i = 0; i < len; i++)
    {
        var value = array[idx, i];
        if (float.IsNaN(value))
        {
            continue;
        }

        // Accept the first candidate unconditionally. A fixed sentinel (previously -1) would miss
        // rows whose values are all below it, such as negative scores.
        if (maxIdx < 0 || value > max)
        {
            maxIdx = i;
            max = value;
        }
    }
    return maxIdx;
}
/// <summary>
/// Copies row <paramref name="index"/> of a 2-D array into a new 1-D array. For example, it pulls
/// one flattened image out of a [imageCount, pixelCount] batch.
/// </summary>
/// <param name="array">Source matrix.</param>
/// <param name="index">Row to copy.</param>
/// <returns>A new array holding the <c>array.GetLength(1)</c> values of that row.</returns>
public static float[] Extract(float[,] array, int index)
{
    int width = array.GetLength(1);
    return Enumerable.Range(0, width)
                     .Select(column => array[index, column])
                     .ToArray();
}
/// <summary>
/// 08 Tensor basics. Prints the value of a scalar tensor and the type, rank, shape and dimensions
/// of a 2x3 matrix tensor. Then evaluates a constant sbyte tensor holding "Hello, world!" and prints
/// it back as text.
/// </summary>
private void button8_Click(object sender, EventArgs e)
{
    // Scalar integer tensor.
    var tensor = new TFTensor(1);
    this.textBox1.Text += "\r\n" + string.Format("Value:{0}", tensor.GetValue());

    // 2x3 integer matrix tensor.
    tensor = new TFTensor(new int[,]
    {
        { 1, 2, 3 },
        { 4, 5, 6 },
    });
    this.textBox1.Text += "\r\n" + string.Format("TensorType: {0}", tensor.TensorType.ToString());
    this.textBox1.Text += "\r\n" + string.Format("NumDims: {0}", tensor.NumDims);
    // The format string needs a {0} item; without one the joined shape was silently dropped.
    this.textBox1.Text += "\r\n" + string.Format("Shape: {0}", string.Join(",", tensor.Shape));
    for (var i = 0; i < tensor.NumDims; i++)
    {
        var dim = tensor.GetTensorDimension(i);
        this.textBox1.Text += "\r\n" + string.Format("DimIndex: {0}, Dim: {1}", i, dim);
    }

    using (var g = new TFGraph())
    {
        // Constant tensor built from the characters of the message, one sbyte per char.
        tensor = new TFTensor("Hello, world!".Select(o => (sbyte)o).ToArray());
        var hello = g.Const(tensor);

        using (var sess = new TFSession(g))
        {
            var result = sess.GetRunner().Run(hello).GetValue();
            // Map the sbytes back to chars. The text is appended directly, not passed through
            // string.Format, so braces in the data could never be misread as format items.
            this.textBox1.Text += "\r\n" + string.Join("", ((sbyte[])result).Select(o => (char)o));
        }
    }
}
将以下三个类文件放到自己的工程中。网上流传的 MNIST.cs 存在多个版本,部分函数的签名和功能各不相同;目前最新的版本如下:
MNIST.cs代码
//
// Code to download and load the MNIST data.
//
using System;
using System.IO;
using System.IO.Compression;
using Mono;
using TensorFlow;
using System.Linq;
namespace Learn.Mnist
{
/// <summary>
/// One MNIST image as loaded from disk. The pixels are stored in two forms: the raw bytes
/// (<see cref="Data"/>) and the same values scaled from 0..255 to 0.0f..1.0f (<see cref="DataFloat"/>).
/// </summary>
public struct MnistImage
{
    /// <summary>Number of pixel columns.</summary>
    public int Cols;
    /// <summary>Number of pixel rows.</summary>
    public int Rows;
    /// <summary>Raw pixel bytes as read from the file (0..255).</summary>
    public byte[] Data;
    /// <summary>Each byte of <see cref="Data"/> divided by 255, i.e. in 0.0f..1.0f.</summary>
    public float[] DataFloat;

    public MnistImage(int cols, int rows, byte[] data)
    {
        var scaled = new float[data.Length];
        for (var k = 0; k < scaled.Length; k++)
        {
            scaled[k] = data[k] / 255f;
        }

        Cols = cols;
        Rows = rows;
        Data = data;
        DataFloat = scaled;
    }
}
/// <summary>
/// Loads the MNIST data set from gzip-compressed IDX files and serves it in batches.
/// </summary>
public class Mnist
{
    // The loaded results. ReadDataSets splits the original training file into training and validation sets.
    public MnistImage[] TrainImages, TestImages, ValidationImages;
    public byte[] TrainLabels, TestLabels, ValidationLabels;
    // One-hot labels ([sample, class]); left null when one-hot encoding is disabled.
    public byte[,] OneHotTrainLabels, OneHotTestLabels, OneHotValidationLabels;

    public BatchReader GetTrainReader() => new BatchReader(TrainImages, TrainLabels, OneHotTrainLabels);
    public BatchReader GetTestReader() => new BatchReader(TestImages, TestLabels, OneHotTestLabels);
    public BatchReader GetValidationReader() => new BatchReader(ValidationImages, ValidationLabels, OneHotValidationLabels);

    /// <summary>
    /// Hands out consecutive batches of images and one-hot labels, advancing a read position.
    /// </summary>
    public class BatchReader
    {
        // Pixels per flattened image and number of label classes produced by NextBatch.
        const int ImageSize = 784;
        const int ClassCount = 10;

        int start = 0;
        MnistImage[] source;
        byte[] labels;
        byte[,] oneHotLabels;

        internal BatchReader(MnistImage[] source, byte[] labels, byte[,] oneHotLabels)
        {
            this.source = source;
            this.labels = labels;
            this.oneHotLabels = oneHotLabels;
        }

        /// <summary>
        /// Copies the next <paramref name="batchSize"/> images and their one-hot labels.
        /// </summary>
        /// <param name="batchSize">Number of samples to read.</param>
        /// <param name="imageData">Receives a [batchSize, 784] array of the images' DataFloat values.</param>
        /// <param name="labelData">Receives a [batchSize, 10] array of one-hot labels.</param>
        public void NextBatch(int batchSize, out float[,] imageData, out float[,] labelData)
        {
            imageData = new float[batchSize, ImageSize];
            labelData = new float[batchSize, ClassCount];
            var rowBytes = ImageSize * sizeof(float);
            for (int item = 0; item < batchSize; item++)
            {
                // Buffer.BlockCopy offsets are in bytes, so row `item` starts at item * rowBytes.
                Buffer.BlockCopy(source[start + item].DataFloat, 0, imageData, item * rowBytes, rowBytes);
                for (var j = 0; j < ClassCount; j++)
                {
                    labelData[item, j] = oneHotLabels[start + item, j];
                }
            }
            start += batchSize;
        }
    }

    // Fills buffer completely. Stream.Read (GZipStream in particular) may return fewer bytes
    // than requested, so a single Read call is not enough.
    static void ReadFully(Stream s, byte[] buffer)
    {
        int offset = 0;
        while (offset < buffer.Length)
        {
            int read = s.Read(buffer, offset, buffer.Length - offset);
            if (read <= 0)
            {
                throw new EndOfStreamException("Unexpected end of MNIST data stream");
            }
            offset += read;
        }
    }

    // Reads a 32-bit big-endian integer (the byte order of the IDX header fields).
    int Read32(Stream s)
    {
        var x = new byte[4];
        ReadFully(s, x);
        return (x[0] << 24) | (x[1] << 16) | (x[2] << 8) | x[3];
    }

    // Decompresses an IDX image file: magic 2051, image count, rows, cols, then rows*cols bytes per image.
    // Disposing the GZipStream also closes `input`.
    MnistImage[] ExtractImages(Stream input, string file)
    {
        using (var gz = new GZipStream(input, CompressionMode.Decompress))
        {
            if (Read32(gz) != 2051)
                throw new InvalidDataException("Invalid magic number found on the MNIST " + file);
            var count = Read32(gz);
            var rows = Read32(gz);
            var cols = Read32(gz);
            var size = rows * cols;
            var result = new MnistImage[count];
            for (int i = 0; i < count; i++)
            {
                var data = new byte[size];
                ReadFully(gz, data);
                result[i] = new MnistImage(cols, rows, data);
            }
            return result;
        }
    }

    // Decompresses an IDX label file: magic 2049, label count, then one byte per label.
    // Disposing the GZipStream also closes `input`.
    byte[] ExtractLabels(Stream input, string file)
    {
        using (var gz = new GZipStream(input, CompressionMode.Decompress))
        {
            if (Read32(gz) != 2049)
                throw new InvalidDataException("Invalid magic number found on the MNIST " + file);
            var count = Read32(gz);
            var labels = new byte[count];
            ReadFully(gz, labels);
            return labels;
        }
    }

    // Returns a copy of source[first .. last), i.e. `last` is exclusive.
    static T[] Pick<T>(T[] source, int first, int last)
    {
        var result = new T[last - first];
        Array.Copy(source, first, result, 0, result.Length);
        return result;
    }

    // Turns a label array holding values 0..numClasses-1 into a one-hot [labels.Length, numClasses] array.
    static byte[,] OneHot(byte[] labels, int numClasses)
    {
        var oneHot = new byte[labels.Length, numClasses];
        for (int i = 0; i < labels.Length; i++)
        {
            oneHot[i, labels[i]] = 1;
        }
        return oneHot;
    }

    /// <summary>
    /// Reads the data sets, obtaining each file through Helper.MaybeDownload.
    /// </summary>
    /// <param name="trainDir">Directory where the training data is downloaded to.</param>
    /// <param name="numClasses">Number of classes for one-hot encoding, or zero (or negative) to skip one-hot encoding.</param>
    /// <param name="validationSize">Number of leading training samples moved into the validation set.</param>
    public void ReadDataSets(string trainDir, int numClasses = 10, int validationSize = 5000)
    {
        const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
        const string TrainImagesName = "train-images-idx3-ubyte.gz";
        const string TrainLabelsName = "train-labels-idx1-ubyte.gz";
        const string TestImagesName = "t10k-images-idx3-ubyte.gz";
        const string TestLabelsName = "t10k-labels-idx1-ubyte.gz";

        TrainImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TrainImagesName), TrainImagesName);
        TestImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TestImagesName), TestImagesName);
        TrainLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TrainLabelsName), TrainLabelsName);
        TestLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TestLabelsName), TestLabelsName);

        // Explicit end indices: the previous "last == 0 means to the end" rule made
        // validationSize == 0 copy the whole training set into the validation set.
        ValidationImages = Pick(TrainImages, 0, validationSize);
        ValidationLabels = Pick(TrainLabels, 0, validationSize);
        TrainImages = Pick(TrainImages, validationSize, TrainImages.Length);
        TrainLabels = Pick(TrainLabels, validationSize, TrainLabels.Length);

        // Zero or negative disables one-hot encoding, as documented. The previous check (!= -1) crashed on 0.
        if (numClasses > 0)
        {
            OneHotTrainLabels = OneHot(TrainLabels, numClasses);
            OneHotValidationLabels = OneHot(ValidationLabels, numClasses);
            OneHotTestLabels = OneHot(TestLabels, numClasses);
        }
    }

    /// <summary>
    /// Loads MNIST using the "tmp" subdirectory of the current working directory as the data directory.
    /// </summary>
    public static Mnist Load()
    {
        var x = new Mnist();
        x.ReadDataSets(Path.Combine(Environment.CurrentDirectory, "tmp"));
        return x;
    }
}
}
DataConverter.cs文件
//
// Authors:
// Miguel de Icaza (miguel@novell.com)
//
// See the following url for documentation:
// http://www.mono-project.com/Mono_DataConvert
//
// Compilation Options:
// MONO_DATACONVERTER_PUBLIC:
// Makes the class public instead of the default internal.
//
// MONO_DATACONVERTER_STATIC_METHODS:
// Exposes the public static methods.
//
// TODO:
// Support for "DoubleWordsAreSwapped" for ARM devices
//
// Copyright (C) 2006 Novell, Inc (http://www.novell.com)
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software withou