前言
本文基于CNTK实现图像分类。与之前的文章不同,本次使用C#实现,不使用Python:Python版的CNTK比较简单,而且个人感觉没什么必要,毕竟CNTK是微软的框架,本人强迫症犯了,所以选择用C#来实现。
环境版本:
Visual Studio 2022
C# / .NET Framework 4.6
cntk 2.7
cuda 10.1
一、数据集准备
本次数据集使用中国象棋数据集,如图:
DataImage_train:图像训练文件夹
DataImage_val:图像验证
test:图像测试
二、图像分类程序构建
1.变量定义
下面定义的变量涵盖了训练、验证和测试三个阶段所需的全部配置。
// Static configuration shared by the train / val / test phases.
// NOTE: C# runs static field initializers in textual order, so fields that are read by
// a later initializer (useGPU -> device, model_type -> IMAGE_WIDTH/IMAGE_HEIGHT) must stay above it.

// Wrapper around the loaded CNTK model; assigned at the start of each phase.
static CntkModelWrapper _modelWrapper = null;
// Names of the CNTK input (feature) and output (label) streams.
private const string FEATURE_STREAM_NAME = "features";
private const string LABEL_STREAM_NAME = "labels";
// Directory that holds the pre-trained base model files.
static string BaseWorkPath = @"./Base_Model";
// Whether to run on the GPU; when true, GPU device 0 is selected, otherwise the CPU device.
static bool useGPU = true;
static DeviceDescriptor device = useGPU ? DeviceDescriptor.GPUDevice(0) : DeviceDescriptor.CPUDevice;
// Index of the base model to use, followed by the list of available base model files.
static int model_type = 0;
static string[] base_model_file = new string[] { "AlexNet_ImageNet_CNTK.model", "InceptionV3_ImageNet_CNTK.model", "ResNet18_ImageNet_CNTK.model", "ResNet34_ImageNet_CNTK.model", "ResNet50_ImageNet_CNTK.model", "ResNet101_ImageNet_CNTK.model", "ResNet152_ImageNet_CNTK.model" };
// Input image size: 227 for model_type 0 (AlexNet entry), 299 for model_type 1 (InceptionV3 entry), 224 for every other entry.
static int IMAGE_WIDTH = (model_type == 0) ? 227 : ((model_type == 1) ? 299 : 224);
static int IMAGE_HEIGHT = (model_type == 0) ? 227 : ((model_type == 1) ? 299 : 224);
static int IMAGE_DEPTH = 3;
// Learning rate.
static float learning_rate = 0.0001f;
// Number of samples per minibatch.
static uint batch_size = 4;
// Number of training epochs; the training loop stops once the computed epoch exceeds this value.
static int TrainNum = 300;
// Class names; filled at training time from the sub-directory names of the training image folder.
static string[] classes_names = new string[]{ };
// Whether to regenerate the dataset files and whether to rebuild the model before training.
static bool reCreateData = false;
static bool reCreateModel = true;
// Folder with the training images.
static string ImageDir_Train = @"./DataSet_Classification_Chess\DataImage_train";
// Folder with the validation images.
static string ImageDir_Val = @"./DataSet_Classification_Chess\DataImage_val";
// Folder with the images used for prediction.
static string ImageDir_Test = @"./DataSet_Classification_Chess\test";
// Image file extension (without the dot).
static string ext = "bmp";
// Output directory for the model, the config file and the dataset files.
static string model_path = "./result_Model";
// File name of the model produced by training.
static string model_file = "Ctu_CNTK.model";
// File name of the saved config, which stores the class names and their order.
static string config_file = "Ctu_Config.txt";
// File name of the generated training dataset file.
static string train_data_file = "train-dataset.txt";
// File name of the generated validation dataset file.
static string test_data_file = "test-dataset.txt";
// Run mode of the program: "train", "val" or "test".
static string RunModel = "train";
2.模型文件生成
预训练模型是针对ImageNet数据集的1000个类别训练的,而本文使用的是自定义数据集,因此需要对模型的末端做一点修改:克隆并冻结预训练模型中指定隐藏层之前的部分,再接上一个输出维度等于自定义类别数的全连接层。
/// <summary>
/// Builds a transfer-learning model from a pre-trained base model: the part of
/// <paramref name="baseModel"/> that ends at <paramref name="hiddenNodeName"/> is cloned with
/// frozen parameters, fed from a new input variable, and followed by a new dense layer
/// with <paramref name="numClasses"/> outputs.
/// </summary>
/// <param name="baseModel">The pre-trained model to clone from.</param>
/// <param name="featureNodeName">Name of the base model argument that is replaced by the new input.</param>
/// <param name="outputNodeName">Name given to the new dense output layer.</param>
/// <param name="hiddenNodeName">Name of the node in the base model at which the clone is cut.</param>
/// <param name="imageDims">Shape of the new input variable.</param>
/// <param name="numClasses">Number of outputs of the new dense layer.</param>
/// <param name="device">Device passed to the dense layer construction.</param>
/// <returns>The new model ending in the freshly created dense layer.</returns>
/// <exception cref="ArgumentNullException"><paramref name="baseModel"/> is null.</exception>
/// <exception cref="ArgumentException">No node named <paramref name="hiddenNodeName"/> exists in the base model.</exception>
/// <exception cref="InvalidOperationException">
/// The base model does not have exactly one argument named <paramref name="featureNodeName"/> (thrown by <c>Single</c>).
/// </exception>
public static Function BuildTransferLearningModel(Function baseModel, string featureNodeName, string outputNodeName, string hiddenNodeName, int[] imageDims, int numClasses, DeviceDescriptor device)
{
    if (baseModel == null)
    {
        throw new ArgumentNullException(nameof(baseModel));
    }

    var input = Variable.InputVariable(imageDims, DataType.Float);

    // Subtract the constant 114 from every input value before it reaches the cloned layers.
    var normalizedFeatureNode = CNTKLib.Minus(input, Constant.Scalar(DataType.Float, 114.0F));

    var oldFeatureNode = baseModel.Arguments.Single(a => a.Name == featureNodeName);

    var lastNode = baseModel.FindByName(hiddenNodeName);
    if (lastNode == null)
    {
        // Fail with a clear message instead of an obscure error from AsComposite further down.
        throw new ArgumentException($"Node '{hiddenNodeName}' was not found in the base model.", nameof(hiddenNodeName));
    }

    // Clone up to the hidden node with frozen weights, rewiring the old feature input to the new one.
    var clonedLayer = CNTKLib.AsComposite(lastNode).Clone(
        ParameterCloningMethod.Freeze,
        new Dictionary<Variable, Variable>()
        {
            { oldFeatureNode, normalizedFeatureNode }
        });

    // New trainable output layer sized for the custom class count; no activation is applied here.
    var clonedModel = Dense(clonedLayer, numClasses, device, Activation.None, outputNodeName);
    return clonedModel;
}
// Build and save the model when a rebuild is requested or no saved model file exists yet.
if (reCreateModel || !File.Exists(Path.Combine(model_path, model_file)))
{
    CreateAndSaveModel(
        config,
        Path.Combine(BaseWorkPath, base_model_file[model_type]),
        Path.Combine(model_path, model_file),
        device);
}
3.训练数据集生成
// Regenerate the dataset files when requested or when the output directory does not exist.
// An existing output directory is wiped first, then recreated empty.
if (reCreateData || !Directory.Exists(model_path))
{
    var outputDir = new DirectoryInfo(model_path);
    if (outputDir.Exists)
    {
        foreach (var subDir in outputDir.GetDirectories())
        {
            subDir.Delete(true);
        }
        outputDir.Delete(true);
    }

    Directory.CreateDirectory(model_path);
    CreateAndSaveDatasets(
        config,
        ImageDir_Train,
        Path.Combine(model_path, train_data_file),
        ImageDir_Val,
        Path.Combine(model_path, test_data_file),
        ext);
}
4.训练完整代码
// Training phase: prepares the dataset files and the model, trains for TrainNum epochs,
// saves the trained model, then switches the run mode to "val".
if (RunModel == "train")
{
    // Class names are the sub-directory names of the training image folder.
    classes_names = new DirectoryInfo(ImageDir_Train).GetDirectories().Select(d => d.Name).ToArray();
    var config = new ClassificationConfig(classes_names);

    string modelFilePath = Path.Combine(model_path, model_file);
    string trainDataPath = Path.Combine(model_path, train_data_file);

    if (reCreateData || Directory.Exists(model_path) == false)
    {
        // Delete(true) is recursive, so sub-directories do not need to be removed one by one first.
        DirectoryInfo dir = new DirectoryInfo(model_path);
        if (dir.Exists)
        {
            dir.Delete(true);
        }
        Directory.CreateDirectory(model_path);
        CreateAndSaveDatasets(config, ImageDir_Train, trainDataPath, ImageDir_Val, Path.Combine(model_path, test_data_file), ext);
    }

    if (reCreateModel || File.Exists(modelFilePath) == false)
    {
        CreateAndSaveModel(config, Path.Combine(BaseWorkPath, base_model_file[model_type]), modelFilePath, device);
    }
    config.Save(Path.Combine(model_path, config_file));

    // Train
    _modelWrapper = new CntkModelWrapper(modelFilePath, device);
    var dataSource = CreateDataSource(trainDataPath);
    var trainer = CreateTrainer();
    var minibatchesSeen = 0;

    int data_length = readFileLines(trainDataPath);
    if (data_length <= 0)
    {
        // The epoch computation below divides by data_length; fail early with a clear message.
        throw new InvalidOperationException($"Training data file '{trainDataPath}' contains no samples.");
    }

    // int / uint is evaluated as long. Clamp to at least 1 so that a dataset with fewer
    // samples than batch_size does not cause a modulo-by-zero when printing progress.
    long batchesPerEpoch = Math.Max(1L, data_length / batch_size);

    while (true)
    {
        var minibatchData = dataSource.MinibatchSource.GetNextMinibatch(batch_size, device);
        var arguments = new Dictionary<Variable, MinibatchData>
        {
            { _modelWrapper.Input, minibatchData[dataSource.FeatureStreamInfo] },
            { _modelWrapper.TrainingOutput, minibatchData[dataSource.LabelStreamInfo] }
        };
        trainer.TrainMinibatch(arguments, device);

        double loss = trainer.PreviousMinibatchLossAverage();
        double eval = trainer.PreviousMinibatchEvaluationAverage();

        // 1-based epoch derived from the number of samples processed before this minibatch.
        int epoch = Convert.ToInt32(minibatchesSeen * batch_size / data_length) + 1;
        Console.WriteLine($"[{epoch}:{TrainNum}/{minibatchesSeen % batchesPerEpoch + 1}] CrossEntropyLoss = {loss}, EvaluationCriterion = {eval}");

        minibatchesSeen++;

        // Stop and save once the next minibatch would belong to an epoch beyond TrainNum.
        if (Convert.ToInt32(minibatchesSeen * batch_size / data_length) + 1 > TrainNum)
        {
            _modelWrapper.Model.Save(modelFilePath);
            break;
        }
    }

    // Continue with validation right after training.
    RunModel = "val";
}
5.验证完整代码
// Validation phase: evaluates the saved model on the validation dataset one sample at a time
// and prints the running accuracy after every sample.
if (RunModel == "val")
{
    _modelWrapper = new CntkModelWrapper(Path.Combine(model_path, model_file), device);
    var dataSource_test = CreateDataSource(Path.Combine(model_path, test_data_file));
    const int minibatchSize = 1;
    int Correct = 0;
    int Total = 0;

    // Maps each row of values to the index of its largest element. The result is materialised
    // into a list so that it is computed once even though it is used twice below.
    Func<IEnumerable<IList<float>>, List<int>> maxSelector =
        collection => collection.Select(x => x.IndexOf(x.Max())).ToList();

    while (true)
    {
        var minibatchData = dataSource_test.MinibatchSource.GetNextMinibatch(minibatchSize, device);

        // Passing null as the output value lets Evaluate allocate the result itself.
        var inputDataMap = new Dictionary<Variable, Value>() { { _modelWrapper.Input, minibatchData[dataSource_test.FeatureStreamInfo].data } };
        var outputDataMap = new Dictionary<Variable, Value>() { { _modelWrapper.EvaluationOutput, null } };
        _modelWrapper.Model.Evaluate(inputDataMap, outputDataMap, device);

        var outputVal = outputDataMap[_modelWrapper.EvaluationOutput];
        var actual = outputVal.GetDenseData<float>(_modelWrapper.EvaluationOutput);
        var labelBatch = minibatchData[dataSource_test.LabelStreamInfo].data;
        var expected = labelBatch.GetDenseData<float>(_modelWrapper.Model.Output);

        var actualLabels = maxSelector(actual);
        var expectedLabels = maxSelector(expected);

        Correct += actualLabels.Zip(expectedLabels, (a, b) => a == b ? 1 : 0).Sum();
        Total += actualLabels.Count;

        double acc = (Convert.ToDouble(Correct) / Total);
        Console.WriteLine($"Correct = {Correct}, Total = {Total}, acc={acc}");

        // Stop after the minibatch that completes one full sweep over the validation data.
        if (minibatchData.Values.Any(x => x.sweepEnd))
        {
            break;
        }
    }
}
6.预测完整代码
// Prediction phase: classifies every image with the configured extension in the test folder
// and prints "<file> : <class name>" for each one.
if (RunModel == "test")
{
    _modelWrapper = new CntkModelWrapper(Path.Combine(model_path, model_file), device);
    var config = ClassificationConfig.Load(Path.Combine(model_path, config_file));

    foreach (string imagePath in Directory.GetFiles(ImageDir_Test, $"*.{ext}"))
    {
        // Wrap the loaded image data in a CNTK value with shape width x height x depth.
        var imageView = new NDArrayView(
            new int[] { IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_DEPTH },
            ImageHelper.Load(IMAGE_WIDTH, IMAGE_HEIGHT, imagePath),
            device);
        var imageValue = new Value(imageView);

        var inputs = new Dictionary<Variable, Value>();
        inputs.Add(_modelWrapper.Input, imageValue);

        // A null output value lets Evaluate allocate the result itself.
        var outputs = new Dictionary<Variable, Value>();
        outputs.Add(_modelWrapper.EvaluationOutput, null);

        _modelWrapper.Model.Evaluate(inputs, outputs, device);

        // Only one image is evaluated at a time, so take the first (and only) result row.
        var resultValue = outputs[_modelWrapper.EvaluationOutput];
        var firstRow = resultValue.GetDenseData<float>(_modelWrapper.EvaluationOutput).First();
        double[] scores = firstRow.Select(score => (double)score).ToArray();

        // The predicted class is the index of the highest score.
        int bestIndex = Array.IndexOf(scores, scores.Max());
        var predictedName = config.GetClassNameByIndex(bestIndex);

        Console.WriteLine($"{imagePath} : {predictedName}");
    }
}
训练效果
预测效果
总结
源码私聊