基于CNTK/C#实现Cifar【附源码】


前言

本文讲解如何基于 CNTK + C# 实现 CIFAR-10 图像分类,以及如何构建 ResNet 网络结构。


一、数据集的构建

首先,这里提供一个 1.tar.gz 压缩包和一个 Python 脚本,用于构建数据集。使用者只需运行该 Python 脚本,即可得到如下的 CIFAR-10 文件夹,其中包含训练集和测试集(图片、map 文件以及均值文件等)。
python实现代码如下:

import getopt
import os
import pickle as cp
import shutil
import struct
import sys
import tarfile
import xml.dom.minidom
try:
    import xml.etree.cElementTree as et
except ImportError:  # cElementTree was removed in Python 3.9
    import xml.etree.ElementTree as et

import numpy as np
from PIL import Image

# Side length of a CIFAR image in pixels (images are square, 3 colour channels).
imgSize = 32
# Raw pixel values per image: 3 channels * imgSize * imgSize.
numFeature = 3 * imgSize ** 2

# Output root; every other path below is derived from it.
# NOTE: data_dir already ends in '/', so the derived paths contain '//'.
data_dir = './CIFAR-10/'
train_filename, test_filename = (data_dir + '/Train_cntk_text.txt',
                                 data_dir + '/Test_cntk_text.txt')
train_img_directory, test_img_directory = data_dir + '/Train', data_dir + '/Test'

def readBatch(src):
    """Load one pickled CIFAR-10 batch as a single integer array.

    Each output row is the flattened pixel vector of one image followed by its
    class label, so the result has one more column than ``d['data']``.

    Args:
        src: Path to a ``data_batch_N`` / ``test_batch`` pickle file.
            NOTE: pickle can execute arbitrary code -- only load trusted files.

    Returns:
        ``numpy.ndarray`` with the builtin ``int`` dtype.
    """
    with open(src, 'rb') as f:
        if sys.version_info[0] < 3:
            d = cp.load(f)
        else:
            # The batches were pickled under Python 2; 'latin1' decodes them.
            d = cp.load(f, encoding='latin1')
    labels = np.reshape(d['labels'], (len(d['labels']), 1))
    # ``np.int`` was only an alias of the builtin ``int`` and was removed in
    # NumPy 1.24; using ``int`` keeps the same dtype without the crash.
    return np.hstack((d['data'], labels)).astype(int)

def saveTxt(filename, ndarray):
    """Write samples to ``filename`` in CNTK text format.

    Every row of ``ndarray`` is ``[feature_0, ..., feature_k, label]`` and is
    written as ``|labels <one-hot of label over 10 classes> |features <features>``.
    """
    with open(filename, 'w') as out:
        # Precompute the one-hot strings, e.g. class 2 -> "0 0 1 0 0 0 0 0 0 0".
        oneHot = [' '.join('1' if col == cls else '0' for col in range(10))
                  for cls in range(10)]
        for record in ndarray:
            cells = record.astype(str)
            out.write('|labels {} |features {}\n'.format(oneHot[record[-1]],
                                                         ' '.join(cells[:-1])))

def saveMean(fname, data):
    """Write the mean image to ``fname`` as pretty-printed OpenCV-style XML.

    The root element is ``<opencv_storage>`` with ``Channel``/``Row``/``Col``
    headers and a ``MeanImg`` matrix (``type_id='opencv-matrix'``) holding one
    row of ``3 * imgSize * imgSize`` values formatted with ``%e``.

    Args:
        fname: Destination path (written raw first, then rewritten indented).
        data: Array with ``3 * imgSize * imgSize`` elements.
    """
    flatLen = imgSize * imgSize * 3
    root = et.Element('opencv_storage')
    for tag, text in (('Channel', '3'), ('Row', str(imgSize)), ('Col', str(imgSize))):
        et.SubElement(root, tag).text = text

    matrix = et.SubElement(root, 'MeanImg', type_id='opencv-matrix')
    values = np.reshape(data, (flatLen))
    for tag, text in (('rows', '1'),
                      ('cols', str(flatLen)),
                      ('dt', 'f'),
                      ('data', ' '.join('%e' % v for v in values))):
        et.SubElement(matrix, tag).text = text

    # ElementTree emits compact XML; round-trip through minidom to indent it.
    et.ElementTree(root).write(fname)
    pretty = xml.dom.minidom.parse(fname).toprettyxml(indent = '  ')
    with open(fname, 'w') as f:
        f.write(pretty)

def saveImage(fname, data, label, mapFile, regrFile, pad, **key_parms):
    """Save one CIFAR image as PNG and record it in the map/regression files.

    Args:
        fname: Output PNG path.
        data: Flat pixel vector of one image in CHW order
            (``3 * imgSize * imgSize`` values).
        label: Integer class label written next to ``fname`` in ``mapFile``.
        mapFile: Open text file receiving ``"<fname>\\t<label>"`` lines.
        regrFile: Open text file receiving the per-channel mean, divided by
            255, as a ``|regrLabels`` line.
        pad: Border width in pixels; when > 0 the image is padded with the
            constant value 128 on every side.
        **key_parms: Optional ``mean=<array of shape (3, imgSize, imgSize)>``;
            the UNPADDED image is added to it in place.
    """
    # data in CIFAR-10 dataset is in CHW format.
    pixData = data.reshape((3, imgSize, imgSize))
    if 'mean' in key_parms:
        # In-place so the caller's accumulator sees the running sum.
        key_parms['mean'] += pixData

    if pad > 0:
        pixData = np.pad(pixData, ((0, 0), (pad, pad), (pad, pad)), mode='constant', constant_values=128)

    # Build the image in one call instead of a per-pixel Python loop
    # (previously (imgSize + 2*pad)^2 setter calls per image). PIL expects
    # HWC layout, so move the channel axis last; uint8 HxWx3 maps to 'RGB'.
    hwc = np.ascontiguousarray(np.transpose(pixData, (1, 2, 0)), dtype=np.uint8)
    img = Image.fromarray(hwc)
    img.save(fname)
    mapFile.write("%s\t%d\n" % (fname, label))

    # NOTE(review): the channel mean is taken AFTER padding, so for padded
    # images the 128-valued border is included -- confirm this is intended.
    channelMean = np.mean(pixData, axis=(1, 2))
    regrFile.write("|regrLabels\t%f\t%f\t%f\n" % (channelMean[0]/255.0, channelMean[1]/255.0, channelMean[2]/255.0))

def saveTrainImages(filename, foldername):
    """Export the five CIFAR-10 training batches as padded (pad=4) PNG files.

    Also writes ``train_map.txt`` and ``train_regrLabels.txt`` into
    ``data_dir``, and the mean of the unpadded training images to
    ``CIFAR-10_mean.xml``.

    Args:
        filename: Unused; kept so existing callers keep working.
        foldername: Directory receiving the PNG files (created if missing).
    """
    if not os.path.exists(foldername):
        os.makedirs(foldername)
    outDir = os.path.abspath(foldername)
    dataMean = np.zeros((3, imgSize, imgSize))
    numImages = 0
    with open(os.path.join(data_dir, 'train_map.txt'), 'w') as mapFile, \
            open(os.path.join(data_dir, 'train_regrLabels.txt'), 'w') as regrFile:
        for ifile in range(1, 6):
            batchPath = os.path.join('./cifar-10-batches-py', 'data_batch_' + str(ifile))
            with open(batchPath, 'rb') as f:
                if sys.version_info[0] < 3:
                    data = cp.load(f)
                else:
                    data = cp.load(f, encoding='latin1')
            # Number images consecutively across batches rather than assuming
            # exactly 10000 images per batch (same names for standard CIFAR-10).
            for i, label in enumerate(data['labels']):
                fname = os.path.join(outDir, '%05d.png' % numImages)
                saveImage(fname, data['data'][i, :], label, mapFile, regrFile, 4, mean=dataMean)
                numImages += 1
    # Average over the images actually seen instead of a hard-coded 50000.
    if numImages:
        dataMean = dataMean / numImages
    saveMean(os.path.join(data_dir, 'CIFAR-10_mean.xml'), dataMean)

def saveTestImages(filename, foldername):
    """Export the CIFAR-10 test batch as unpadded PNG files.

    Also writes ``test_map.txt`` and ``test_regrLabels.txt`` into ``data_dir``.

    Args:
        filename: Unused; kept so existing callers keep working.
        foldername: Directory receiving the PNG files (created if missing).
    """
    if not os.path.exists(foldername):
        os.makedirs(foldername)
    outDir = os.path.abspath(foldername)
    with open(os.path.join(data_dir, 'test_map.txt'), 'w') as mapFile, \
            open(os.path.join(data_dir, 'test_regrLabels.txt'), 'w') as regrFile:
        with open(os.path.join('./cifar-10-batches-py', 'test_batch'), 'rb') as f:
            if sys.version_info[0] < 3:
                data = cp.load(f)
            else:
                data = cp.load(f, encoding='latin1')
        # Iterate over the labels actually present instead of a fixed 10000.
        for i, label in enumerate(data['labels']):
            fname = os.path.join(outDir, '%05d.png' % i)
            saveImage(fname, data['data'][i, :], label, mapFile, regrFile, 0)

# Extract the archive; it is expected to contain ./cifar-10-batches-py/.
with tarfile.open(r'1.tar.gz') as tar:
    # Use the 'data' extraction filter (Python 3.12+ and security backports)
    # to reject path-traversal / special members; fall back otherwise.
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(filter='data')
    else:
        tar.extractall()

# Read the five training batches and stack them once: repeated vstack in the
# loop re-copied the growing array, and np.int no longer exists in NumPy>=1.24.
batches = []
for i in range(5):
    batchName = './cifar-10-batches-py/data_batch_{0}'.format(i + 1)
    batches.append(readBatch(batchName))
trn = np.vstack(batches)
tst = readBatch('./cifar-10-batches-py/test_batch')

# Create output directories; skipping existing ones makes the script re-runnable
# (os.makedirs raises if the directory already exists).
for directory in (data_dir, train_img_directory, test_img_directory):
    if not os.path.exists(directory):
        os.makedirs(directory)

saveTxt(os.path.join(data_dir, 'Train_cntk_text.txt'), trn)
saveTxt(os.path.join(data_dir, 'Test_cntk_text.txt'), tst)

saveTrainImages(os.path.join(data_dir, 'Train_cntk_text.txt'), train_img_directory)
saveTestImages(os.path.join(data_dir, 'Test_cntk_text.txt'), test_img_directory)

如图:
在这里插入图片描述
最后得到的目录结构如下:
在这里插入图片描述


二、CNTK/C#->Cifar代码构建

1.首先定义标准变量

// Model file path and number of training epochs.
// MaxEpochs is passed to CreateMinibatchSource (not shown), presumably to cap
// how many sweeps over the training data the source produces.
string modelFile = "./Cifar10Rest.model";
uint MaxEpochs = 120;

2.定义GPU环境

// Device configuration: use the GPU with index 0.
var device = DeviceDescriptor.GPUDevice(0);

3.载入数据集

// Load the training data from the map file and mean file produced by the Python script.
// imageDim, numClasses and CreateMinibatchSource are defined elsewhere (not shown here).
var minibatchSource = CreateMinibatchSource(@".\CIFAR-10\train_map.txt", @".\CIFAR-10\CIFAR-10_mean.xml", imageDim, numClasses, MaxEpochs);
// Stream handles used to pull the image features and labels out of each minibatch.
var imageStreamInfo = minibatchSource.StreamInfo("features");
var labelStreamInfo = minibatchSource.StreamInfo("labels");

4.创建模型

 // Create the model: first declare the input and label variables, then build the network (a ResNet here).
var imageInput = CNTKLib.InputVariable(imageDim, imageStreamInfo.m_elementType, "Images");
var labelsVar = CNTKLib.InputVariable(new int[] { numClasses }, labelStreamInfo.m_elementType, "Labels");
// Unnormalized class scores; softmax is applied inside the loss function.
var classifierOutput = ResNetClassifier(imageInput, numClasses, device, "classifierOutput");

模型结构具体实现代码

/// <summary>
/// Convolution followed by batch normalization (no activation).
/// </summary>
/// <param name="input">Input variable; the last axis of its shape is used as the input channel count.</param>
/// <param name="outFeatureMapCount">Number of output feature maps (convolution filters).</param>
/// <param name="kernelWidth">Kernel width.</param>
/// <param name="kernelHeight">Kernel height.</param>
/// <param name="hStride">Stride along the first spatial axis.</param>
/// <param name="vStride">Stride along the second spatial axis.</param>
/// <param name="wScale">Scale passed to the Glorot-uniform initializer of the kernel.</param>
/// <param name="bValue">Initial value of the batch-normalization bias.</param>
/// <param name="scValue">Initial value of the batch-normalization scale.</param>
/// <param name="bnTimeConst">Normalization time constant for the running statistics.</param>
/// <param name="spatial">CNTK's spatial batch-normalization flag.</param>
/// <param name="device">Device on which parameters and constants are allocated.</param>
/// <returns>The batch-normalized convolution output.</returns>
private static Function ConvBatchNormalizationLayer(Variable input, int outFeatureMapCount, int kernelWidth, int kernelHeight, int hStride, int vStride, double wScale, double bValue, double scValue, int bnTimeConst, bool spatial, DeviceDescriptor device)
{
    int numInputChannels = input.Shape[input.Shape.Rank - 1];

    // Kernel of shape {kernelWidth, kernelHeight, inChannels, outFeatureMapCount}.
    // NOTE: parameters are created in a fixed order; reordering these statements may change
    // auto-selected initializer seeds, so keep the order as is.
    var convParams = new Parameter(new int[] { kernelWidth, kernelHeight, numInputChannels, outFeatureMapCount },
        DataType.Float, CNTKLib.GlorotUniformInitializer(wScale, -1, 2), device);
    // NOTE(review): the third stride equals numInputChannels, presumably so the kernel is applied
    // once across all input channels -- confirm against the CNTK Convolution documentation.
    var convFunction = CNTKLib.Convolution(convParams, input, new int[] { hStride, vStride, numInputChannels });

    // Batch-norm scale/bias plus running statistics. InferredDimension lets CNTK infer
    // their size from the convolution output.
    var biasParams = new Parameter(new int[] { NDShape.InferredDimension }, (float)bValue, device, "");
    var scaleParams = new Parameter(new int[] { NDShape.InferredDimension }, (float)scValue, device, "");
    var runningMean = new Constant(new int[] { NDShape.InferredDimension }, 0.0f, device);
    var runningInvStd = new Constant(new int[] { NDShape.InferredDimension }, 0.0f, device);
    var runningCount = Constant.Scalar(0.0f, device);
    // bnTimeConst is passed as the normalization time constant, 0.0 as (presumably) the blend time constant.
    return CNTKLib.BatchNormalization(convFunction, scaleParams, biasParams, runningMean, runningInvStd, runningCount,
        spatial, (double)bnTimeConst, 0.0, 1e-5 /* epsilon */);
}
/// <summary>
/// Convolution + batch normalization (see ConvBatchNormalizationLayer) followed by a ReLU activation.
/// All parameters are forwarded unchanged to ConvBatchNormalizationLayer.
/// </summary>
private static Function ConvBatchNormalizationReLULayer(Variable input, int outFeatureMapCount, int kernelWidth, int kernelHeight, int hStride, int vStride, double wScale, double bValue, double scValue, int bnTimeConst, bool spatial, DeviceDescriptor device)
    => CNTKLib.ReLU(
        ConvBatchNormalizationLayer(input, outFeatureMapCount, kernelWidth, kernelHeight,
                                    hStride, vStride, wScale, bValue, scValue,
                                    bnTimeConst, spatial, device));
/// <summary>
/// Basic residual block with an identity shortcut:
/// ReLU(ConvBN(ConvBNReLU(input)) + input), both convolutions with stride 1.
/// </summary>
/// <remarks>
/// The identity addition assumes outFeatureMapCount equals the input channel count and that the
/// convolutions preserve the spatial size (presumably via Convolution's default padding).
/// </remarks>
private static Function ResNetNode(Variable input, int outFeatureMapCount, int kernelWidth, int kernelHeight, double wScale, double bValue,double scValue, int bnTimeConst, bool spatial, DeviceDescriptor device)
{
    // Main branch: conv-BN-ReLU, then conv-BN (activation applied after the addition).
    var c1 = ConvBatchNormalizationReLULayer(input, outFeatureMapCount, kernelWidth, kernelHeight, 1, 1, wScale, bValue, scValue, bnTimeConst, spatial, device);
    var c2 = ConvBatchNormalizationLayer(c1, outFeatureMapCount, kernelWidth, kernelHeight, 1, 1, wScale, bValue, scValue, bnTimeConst, spatial, device);
    // Identity shortcut.
    var p = CNTKLib.Plus(c2, input);
    return CNTKLib.ReLU(p);
}
/// <summary>
/// Shortcut (projection) branch of a downsampling residual block: a strided convolution with the
/// fixed kernel <paramref name="wProj"/> followed by spatial batch normalization.
/// </summary>
/// <param name="wProj">Projection kernel; assumed to have shape {1, 1, inChannels, outChannels} as built by GetProjectionMap.</param>
/// <param name="input">Input variable; the last axis of its shape is used as the input channel count.</param>
/// <param name="hStride">Stride along the first spatial axis.</param>
/// <param name="vStride">Stride along the second spatial axis.</param>
/// <param name="bValue">Initial batch-normalization bias.</param>
/// <param name="scValue">Initial batch-normalization scale.</param>
/// <param name="bnTimeConst">Normalization time constant for the running statistics.</param>
/// <param name="device">Device on which parameters and constants are allocated.</param>
/// <returns>The batch-normalized projection of <paramref name="input"/>.</returns>
private static Function ProjectLayer(Variable wProj, Variable input, int hStride, int vStride, double bValue, double scValue, int bnTimeConst, DeviceDescriptor device)
{
    // The number of output feature maps is the LAST axis of the kernel shape
    // {1, 1, inChannels, outChannels}. The former wProj.Shape[0] is the kernel width (1),
    // which sized every batch-normalization parameter to one element instead of one per
    // output channel.
    int outFeatureMapCount = wProj.Shape[wProj.Shape.Rank - 1];
    // Per-channel BN bias/scale and running statistics (mean, inverse std, sample count).
    var b = new Parameter(new int[] { outFeatureMapCount }, (float)bValue, device, "");
    var sc = new Parameter(new int[] { outFeatureMapCount }, (float)scValue, device, "");
    var m = new Constant(new int[] { outFeatureMapCount }, 0.0f, device);
    var v = new Constant(new int[] { outFeatureMapCount }, 0.0f, device);

    var n = Constant.Scalar(0.0f, device);

    int numInputChannels = input.Shape[input.Shape.Rank - 1];

    // Strided convolution with the fixed kernel; the bool arrays are passed as CNTK's
    // sharing ({ true }) and autoPadding ({ false }) arguments.
    var c = CNTKLib.Convolution(wProj, input, new int[] { hStride, vStride, numInputChannels }, new bool[] { true }, new bool[] { false });
    return CNTKLib.BatchNormalization(c, sc, b, m, v, n, true /*spatial*/, (double)bnTimeConst, 0, 1e-5, false);
}
/// <summary>
/// Downsampling residual block: the main branch starts with a stride-2 conv-BN-ReLU followed by a
/// stride-1 conv-BN; the shortcut is <paramref name="wProj"/> applied with stride 2 plus BN (ProjectLayer).
/// The two branches are added and passed through a ReLU.
/// </summary>
/// <param name="wProj">Fixed projection kernel that maps the input channels to outFeatureMapCount channels.</param>
private static Function ResNetNodeInc(Variable input, int outFeatureMapCount, int kernelWidth, int kernelHeight, double wScale, double bValue, double scValue, int bnTimeConst, bool spatial, Variable wProj, DeviceDescriptor device)
{
    // Main branch; the first convolution halves the spatial resolution (stride 2).
    var c1 = ConvBatchNormalizationReLULayer(input, outFeatureMapCount, kernelWidth, kernelHeight, 2, 2, wScale, bValue, scValue, bnTimeConst, spatial, device);
    var c2 = ConvBatchNormalizationLayer(c1, outFeatureMapCount, kernelWidth, kernelHeight, 1, 1, wScale, bValue, scValue, bnTimeConst, spatial, device);

    // Shortcut branch with the same stride so both branches have matching shapes.
    var cProj = ProjectLayer(wProj, input, 2, 2, bValue, scValue, bnTimeConst, device);

    var p = CNTKLib.Plus(c2, cProj);
    return CNTKLib.ReLU(p);
}
/// <summary>
/// Builds a constant 1x1 convolution kernel of shape {1, 1, inputDim, outputDim} that copies the
/// input channels into the first <paramref name="inputDim"/> output channels; all other entries are zero.
/// Used as the shortcut projection when a ResNet stage increases the number of feature maps.
/// </summary>
/// <param name="outputDim">Number of output channels; must be at least <paramref name="inputDim"/>.</param>
/// <param name="inputDim">Number of input channels.</param>
/// <param name="device">Device on which the constant is allocated.</param>
/// <returns>The projection kernel as a CNTK Constant.</returns>
/// <exception cref="ArgumentException">Thrown when inputDim is greater than outputDim.</exception>
private static Constant GetProjectionMap(int outputDim, int inputDim, DeviceDescriptor device)
{
    if (inputDim > outputDim)
        throw new ArgumentException("Can only project from lower to higher dimensionality");

    // C# arrays are zero-initialised, so only the identity entries need to be set.
    // NOTE(review): with CNTK's column-major layout, element (0, 0, i, j) is at i + inputDim * j,
    // so the identity entry (i, i) is at i * (inputDim + 1) -- confirm if the layout changes.
    float[] projectionMapValues = new float[inputDim * outputDim];
    for (int i = 0; i < inputDim; ++i)
        projectionMapValues[(i * inputDim) + i] = 1.0f;

    var shape = new int[] { 1, 1, inputDim, outputDim };
    var projectionMap = new NDArrayView(DataType.Float, shape, device);
    projectionMap.CopyFrom(new NDArrayView(shape, projectionMapValues, (uint)projectionMapValues.Length, device));

    return new Constant(projectionMap);
}
/// <summary>
/// Builds the CIFAR ResNet classifier: an initial conv-BN-ReLU layer, three stages of three residual
/// blocks each (16, 32 and 64 feature maps; stages 2 and 3 start with a stride-2 downsampling block),
/// 8x8 average pooling and a final dense layer producing numOutputClasses unnormalized scores.
/// </summary>
/// <param name="input">Image input variable (channels on the last axis).</param>
/// <param name="numOutputClasses">Number of output classes.</param>
/// <param name="device">Device on which parameters are allocated.</param>
/// <param name="outputName">Name given to the output node.</param>
private static Function ResNetClassifier(Variable input, int numOutputClasses, DeviceDescriptor device, string outputName)
{
    // Model construction hyper-parameters.
    double convWScale = 7.07;
    double convBValue = 0;

    double fc1WScale = 0.4;
    double fc1BValue = 0;

    double scValue = 1;
    int bnTimeConst = 4096;

    int kernelWidth = 3;
    int kernelHeight = 3;

    double conv1WScale = 0.26;
    int cMap1 = 16;
    // Convolution + batch normalization + ReLU.
    var conv1 = ConvBatchNormalizationReLULayer(input, cMap1, kernelWidth, kernelHeight, 1, 1, conv1WScale, convBValue, scValue, bnTimeConst, true /*spatial*/, device);

    // Residual blocks.
    // NOTE(review): the spatial flag alternates between blocks (false/true/false, ...); standard
    // ResNets use spatial batch normalization after every convolution -- confirm this is intended.
    var rn1_1 = ResNetNode(conv1, cMap1, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, false /*spatial*/, device);
    var rn1_2 = ResNetNode(rn1_1, cMap1, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, true /*spatial*/, device);
    var rn1_3 = ResNetNode(rn1_2, cMap1, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, false /*spatial*/, device);

    int cMap2 = 32;
    var rn2_1_wProj = GetProjectionMap(cMap2, cMap1, device);
    var rn2_1 = ResNetNodeInc(rn1_3, cMap2, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, true /*spatial*/, rn2_1_wProj, device);
    var rn2_2 = ResNetNode(rn2_1, cMap2, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, false /*spatial*/, device);
    var rn2_3 = ResNetNode(rn2_2, cMap2, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, true /*spatial*/, device);

    int cMap3 = 64;
    var rn3_1_wProj = GetProjectionMap(cMap3, cMap2, device);
    var rn3_1 = ResNetNodeInc(rn2_3, cMap3, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, true /*spatial*/, rn3_1_wProj, device);
    var rn3_2 = ResNetNode(rn3_1, cMap3, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, false /*spatial*/, device);
    var rn3_3 = ResNetNode(rn3_2, cMap3, kernelWidth, kernelHeight, convWScale, convBValue, scValue, bnTimeConst, false /*spatial*/, device);

    // Global average pooling.
    // NOTE(review): an 8x8 window is global only if the feature map here is 8x8, i.e. the network
    // input is 32x32 after the two stride-2 stages (training PNGs are padded to 40x40, so the reader
    // presumably crops them) -- confirm in CreateMinibatchSource.
    int poolW = 8;
    int poolH = 8;
    int poolhStride = 1;
    int poolvStride = 1;
    var pool = CNTKLib.Pooling(rn3_3, PoolingType.Average,
        new int[] { poolW, poolH, 1 }, new int[] { poolhStride, poolvStride, 1 });

    // Output layer: dense projection from the cMap3 pooled features to numOutputClasses scores plus bias.
    var outTimesParams = new Parameter(new int[] { numOutputClasses, 1, 1, cMap3 }, DataType.Float,
        CNTKLib.GlorotUniformInitializer(fc1WScale, 1, 0), device);
    var outBiasParams = new Parameter(new int[] { numOutputClasses }, (float)fc1BValue, device, "");

    return CNTKLib.Plus(CNTKLib.Times(outTimesParams, pool), outBiasParams, outputName);
}

5.训练指标

// Training/validation criteria: cross-entropy loss and a classification ERROR rate (not accuracy).
var trainingLoss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labelsVar, "lossFunction");
// NOTE(review): topN = 5 reports top-5 error; with only 10 CIFAR classes this is very lenient.
// Use 1 for the usual top-1 error.
var prediction = CNTKLib.ClassificationError(classifierOutput, labelsVar, 5, "predictionError");   //top5

6.学习率设置

// Per-sample learning rate 0.0078125 (= 1/128). NOTE(review): the second argument (1) is presumably
// the minibatch size the rate refers to, which makes it per-sample; confirm this in the CNTK docs.
var learningRatePerSample = new TrainingParameterScheduleDouble(0.0078125, 1);

7.训练器

// Trainer: minimizes trainingLoss with plain SGD over all parameters of classifierOutput
// and reports `prediction` as the evaluation metric.
var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction,
                new List<Learner> { Learner.SGDLearner(classifierOutput.Parameters(), learningRatePerSample) });

8.开始训练

// Train until the minibatch source returns no more data. minibatchSize, miniBatchCount,
// outputFrequencyInMinibatches and PrintTrainingProgress are defined elsewhere (not shown).
while (true)
{
    // Fetch the next minibatch.
    var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);

    // An empty minibatch means the source is exhausted (presumably after MaxEpochs sweeps).
    if (minibatchData.empty())
    {
        break;
    }

    // Train on this minibatch, binding the image and label streams to the model inputs.
    trainer.TrainMinibatch(new Dictionary<Variable, MinibatchData>()
            { { imageInput, minibatchData[imageStreamInfo] }, { labelsVar, minibatchData[labelStreamInfo] } }, device);

    // Print training progress.
    PrintTrainingProgress(trainer, miniBatchCount++, outputFrequencyInMinibatches);
}
// NOTE(review): this snippet never saves the trained model to modelFile, yet
// ValidateModelWithMinibatchSource loads it from disk. A save step (e.g. classifierOutput.Save(modelFile))
// is presumably needed after the loop -- confirm.

9.在验证集上评估模型

/// <summary>
/// Loads a saved model and measures its top-1 classification error on minibatches drawn from
/// <paramref name="testMinibatchSource"/>, printing a running tally after every batch.
/// Stops when the source is exhausted or once more than <paramref name="maxCount"/> samples were seen.
/// </summary>
/// <param name="modelFile">Path of the saved model.</param>
/// <param name="testMinibatchSource">Source of evaluation minibatches.</param>
/// <param name="imageDim">Unused by this method.</param>
/// <param name="numClasses">Unused by this method.</param>
/// <param name="featureInputName">Name of the feature stream in the source.</param>
/// <param name="labelInputName">Name of the label stream in the source.</param>
/// <param name="outputName">Name of the model output to evaluate.</param>
/// <param name="device">Device used for loading and evaluation.</param>
/// <param name="maxCount">Sample count after which evaluation stops.</param>
/// <returns>Misclassified samples divided by total samples evaluated.</returns>
public static float ValidateModelWithMinibatchSource(string modelFile, MinibatchSource testMinibatchSource, int[] imageDim, int numClasses, string featureInputName, string labelInputName, string outputName, DeviceDescriptor device, int maxCount = 1000)
{
    Function model = Function.Load(modelFile, device);
    var featureVar = model.Arguments[0];
    var outputVar = model.Outputs.Single(o => o.Name == outputName);

    var featureStream = testMinibatchSource.StreamInfo(featureInputName);
    var labelStream = testMinibatchSource.StreamInfo(labelInputName);

    const int batchSize = 50;
    int errors = 0;
    int seen = 0;
    for (;;)
    {
        var batch = testMinibatchSource.GetNextMinibatch((uint)batchSize, device);
        if (batch == null || batch.Count == 0)
        {
            break;
        }

        var features = batch[featureStream];
        seen += (int)features.numberOfSamples;

        // Ground-truth class = index of the largest entry of each one-hot label row.
        var truth = batch[labelStream].data.GetDenseData<float>(outputVar)
            .Select(row => row.IndexOf(row.Max()))
            .ToList();

        var outputs = new Dictionary<Variable, Value>() { { outputVar, null } };
        model.Evaluate(new Dictionary<Variable, Value>() { { featureVar, features.data } }, outputs, device);

        // Predicted class = arg-max of each output row.
        var predicted = outputs[outputVar].GetDenseData<float>(outputVar)
            .Select(row => row.IndexOf(row.Max()))
            .ToList();

        errors += predicted.Zip(truth, (p, t) => p.Equals(t) ? 0 : 1).Sum();
        Console.WriteLine($"Validating Model: Total Samples = {seen}, Misclassify Count = {errors}");

        if (seen > maxCount)
        {
            break;
        }
    }

    float errorRate = 1.0F * errors / seen;
    Console.WriteLine($"Model Validation Error = {errorRate}");
    return errorRate;
}

三、训练结果展示

在这里插入图片描述
在这里插入图片描述


源码

源码下载

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

爱学习的广东仔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值