文章目录
前言
本文基于CNTK实现MNIST手写数字识别,并分别使用CNN和MLP两种方法进行测试。
一、环境搭建
因为本次使用C#,我们只需要把C#所需的依赖库整理好并放入项目中即可。这里我把所需的全部dll都整理到了win64文件夹中,在C#项目工程中只需要引用Cntk.Core.Managed-2.7.dll和netstandard.dll即可。
如图:
二、MNIST代码解析
这里我们直接上代码,并解释每行代码的含义以及构建的思路和步骤。
1.GPU/CPU的设置
// Select the compute device: the GPU with id 0.
// On a machine without a GPU, DeviceDescriptor.CPUDevice can be used instead.
DeviceDescriptor device = DeviceDescriptor.GPUDevice(0);
2.参数变量的设置
// Network selection: true builds the CNN, false builds the MLP.
bool useConvolution = true;

// Names used for the data streams and for the classifier output node.
string featureStreamName = "features";
string labelsStreamName = "labels";
string classifierName = "classifierOutput";

// Assigned later, once the chosen network has been built.
Function classifierOutput;

// Input and output sizes: one image is 28 * 28 = 784 values, and there are 10 classes.
int imageSize = 28 * 28;
int numClasses = 10;
// The CNN takes the image as 28 x 28 x 1; the MLP takes it as a flat vector of imageSize values.
int[] imageDim = useConvolution ? new[] { 28, 28, 1 } : new[] { imageSize };

// File the trained model is saved to; it differs per network type.
string modelFile = useConvolution ? "./MNIST_CNN.model" : "./MNIST_MLP.model";
3.关联模型构造
// Describe the two streams in the data file and their sizes:
// the feature stream carries imageSize (28 * 28) values, the label stream carries numClasses values.
IList<StreamConfiguration> streamConfigurations = new[]
{
    new StreamConfiguration(featureStreamName, imageSize),
    new StreamConfiguration(labelsStreamName, numClasses),
};

// Network input (image) and expected output (label) variables, both float.
Variable input = CNTKLib.InputVariable(imageDim, DataType.Float, featureStreamName);
Variable labels = CNTKLib.InputVariable(new[] { numClasses }, DataType.Float, labelsStreamName);
4.模型构建
// Both networks consume the same scaled input: every element is multiplied by 1/256 (= 0.00390625),
// presumably to bring 0-255 pixel values into the [0, 1) range.
const float pixelScale = 1f / 256f;
var scaledInput = CNTKLib.ElementTimes(Constant.Scalar<float>(pixelScale, device), input);

if (useConvolution)
{
    // CNN classifier.
    classifierOutput = CreateConvolutionalNeuralNetwork(scaledInput, numClasses, device, classifierName);
}
else
{
    // MLP classifier with a hidden layer of 200 nodes.
    int hiddenLayerDim = 200;
    classifierOutput = CreateMLPClassifier(device, numClasses, hiddenLayerDim, scaledInput, classifierName);
}
/// <summary>
/// Builds the CNN classifier: two ConvolutionWithMaxPooling stages followed by a
/// Dense output layer with no activation.
/// </summary>
/// <param name="features">Input variable fed to the first stage.</param>
/// <param name="outDims">Output dimension of the final dense layer.</param>
/// <param name="device">Device passed through to every layer helper.</param>
/// <param name="classifierName">Name given to the final dense layer.</param>
/// <returns>The function produced by the final dense layer.</returns>
static Function CreateConvolutionalNeuralNetwork(Variable features, int outDims, DeviceDescriptor device, string classifierName)
{
    // Settings shared by both stages: 3x3 kernels, stride 2 in both directions, 3x3 pooling windows.
    const int kernelSize = 3;
    const int stride = 2;
    const int poolingWindowSize = 3;

    // Stage 1: 1 input channel -> 4 feature maps
    // (shape per the original author's note: 28x28x1 -> 14x14x4).
    const int firstStageInputChannels = 1;
    const int firstStageFeatureMaps = 4;
    Function firstStage = ConvolutionWithMaxPooling(
        features, device,
        kernelSize, kernelSize,
        firstStageInputChannels, firstStageFeatureMaps,
        stride, stride,
        poolingWindowSize, poolingWindowSize);

    // Stage 2: 4 input channels (the previous stage's feature maps) -> 8 feature maps
    // (shape per the original author's note: 14x14x4 -> 7x7x8).
    const int secondStageFeatureMaps = 8;
    Function secondStage = ConvolutionWithMaxPooling(
        firstStage, device,
        kernelSize, kernelSize,
        firstStageFeatureMaps, secondStageFeatureMaps,
        stride, stride,
        poolingWindowSize, poolingWindowSize);

    // Final dense layer mapping to outDims outputs, without an activation function.
    return Dense(secondStage, outDims, device, Activation.None, classifierName);
}
5.评价指标
// Training loss: softmax cross-entropy between the classifier output and the labels.
Function trainingLoss = CNTKLib.CrossEntropyWithSoftmax(
    new Variable(classifierOutput), labels, "lossFunction");

// Evaluation metric: classification error against the labels (an error rate, not an accuracy).
Function prediction = CNTKLib.ClassificationError(
    new Variable(classifierOutput), labels, "classificationError");
6.数据集加载
// Open the training data file using the stream layout from streamConfigurations.
// InfinitelyRepeat is passed so the source keeps supplying minibatches across sweeps.
MinibatchSource minibatchSource = MinibatchSource.TextFormatMinibatchSource(
    "./mnist_data/MNIST_Train_cntk_text.txt",
    streamConfigurations,
    MinibatchSource.InfinitelyRepeat);

// Stream handles used to pick the features and labels out of each minibatch.
StreamInformation featureStreamInfo = minibatchSource.StreamInfo(featureStreamName);
StreamInformation labelStreamInfo = minibatchSource.StreamInfo(labelsStreamName);
7.学习率设置
// Learning-rate schedule for SGD: a constant 0.003125.
// The second argument (1) is presumably the reference minibatch size, which would make
// this a per-sample rate as the variable name says - confirm against the CNTK API.
var learningRatePerSample = new TrainingParameterScheduleDouble(0.003125, 1);

// A single SGD learner that updates all parameters of the classifier.
IList<Learner> parameterLearners = new List<Learner>
{
    Learner.SGDLearner(classifierOutput.Parameters(), learningRatePerSample),
};
8.获取模型训练器
// Bundle the model, its loss, its evaluation metric and the learners into one trainer.
Trainer trainer = Trainer.CreateTrainer(
    classifierOutput,
    trainingLoss,
    prediction,
    parameterLearners);
9.模型训练
// Training loop settings.
const uint minibatchSize = 64;
int outputFrequencyInMinibatches = 20;
int i = 0; // number of minibatches trained so far

// NOTE(review): epochs is decremented only when MiniBatchDataIsSweepEnd returns true,
// which presumably marks one full pass over the training file. 10000 would then mean
// 10000 full passes - confirm this value is intended.
int epochs = 10000;

while (epochs > 0)
{
    var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);

    // Bind this minibatch's features and labels to the network's input variables.
    var arguments = new Dictionary<Variable, MinibatchData>();
    arguments.Add(input, minibatchData[featureStreamInfo]);
    arguments.Add(labels, minibatchData[labelStreamInfo]);

    // One training step on this minibatch.
    trainer.TrainMinibatch(arguments, device);

    // Report progress, then advance the minibatch counter.
    PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);
    i++;

    bool sweepEnded = MiniBatchDataIsSweepEnd(minibatchData.Values);
    if (sweepEnded)
    {
        epochs--;
    }
}
10.模型保存
// Save the trained classifier to the path held in modelFile.
classifierOutput.Save(modelFile);
11.模型验证
// Validate the saved model against the test data file.
// FullDataSweep is passed instead of InfinitelyRepeat, presumably so the test file is read once.
MinibatchSource minibatchSourceNewModel = MinibatchSource.TextFormatMinibatchSource(
    "./mnist_data/MNIST_Test_cntk_text.txt",
    streamConfigurations,
    MinibatchSource.FullDataSweep);

ValidateModelWithMinibatchSource(
    modelFile,
    minibatchSourceNewModel,
    imageDim,
    numClasses,
    featureStreamName,
    labelsStreamName,
    classifierName,
    device);