Image Classification with CNTK and C# [Partial Source Code and Models Included]


Preface

This article implements image classification with CNTK, but unlike my previous posts, this time the implementation is in C# rather than Python. The Python version of CNTK is fairly simple, and I personally see little point in using CNTK from Python; since it is a Microsoft framework, my perfectionism kicked in and I decided to use CNTK directly from C#.
Environment:
Visual Studio 2022
C# (.NET Framework 4.6)
CNTK 2.7
CUDA 10.1


I. Dataset Preparation

This example uses a Chinese chess (Xiangqi) piece dataset, organized as shown below:
[Screenshots: dataset folder structure]
DataImage_train: training image folder
DataImage_val: validation image folder
test: test image folder


II. Building the Image Classification Program

1. Variable Definitions

The variables defined here cover training, validation, and testing.

//CNTK model wrapper
static CntkModelWrapper _modelWrapper = null;
//Names of the CNTK input (feature) and output (label) streams
private const string FEATURE_STREAM_NAME = "features";
private const string LABEL_STREAM_NAME = "labels";
//Folder containing the pre-trained base models
static string BaseWorkPath = @"./Base_Model";
//Whether to use the GPU
static bool useGPU = true;
static DeviceDescriptor device = useGPU ? DeviceDescriptor.GPUDevice(0) : DeviceDescriptor.CPUDevice;
//Index of the base model to train with, and the list of available base models
static int model_type = 0;
static string[] base_model_file = new string[] { "AlexNet_ImageNet_CNTK.model", "InceptionV3_ImageNet_CNTK.model", "ResNet18_ImageNet_CNTK.model", "ResNet34_ImageNet_CNTK.model", "ResNet50_ImageNet_CNTK.model", "ResNet101_ImageNet_CNTK.model", "ResNet152_ImageNet_CNTK.model" };
//Training image size (AlexNet: 227, InceptionV3: 299, ResNet: 224)
static int IMAGE_WIDTH = (model_type == 0) ? 227 : ((model_type == 1) ? 299 : 224);
static int IMAGE_HEIGHT = (model_type == 0) ? 227 : ((model_type == 1) ? 299 : 224);
static int IMAGE_DEPTH = 3;
//Learning rate
static float learning_rate = 0.0001f;
//Minibatch size
static uint batch_size = 4;
//Number of training epochs
static int TrainNum = 300;

//Class names, populated by reading the sub-folder names of the training set
static string[] classes_names = new string[]{ };
//Whether to regenerate the dataset files and whether to rebuild the model
static bool reCreateData = false;
static bool reCreateModel = true;
//Training dataset folder
static string ImageDir_Train = @"./DataSet_Classification_Chess\DataImage_train";
//Validation dataset folder
static string ImageDir_Val = @"./DataSet_Classification_Chess\DataImage_val";
//Test image folder
static string ImageDir_Test = @"./DataSet_Classification_Chess\test";
//Image file extension
static string ext = "bmp";
//Output folder for the trained model
static string model_path = "./result_Model";
//File name of the trained model
static string model_file = "Ctu_CNTK.model";
//Config file storing the class names and their order
static string config_file = "Ctu_Config.txt";
//Training dataset list file
static string train_data_file = "train-dataset.txt";
//Validation dataset list file
static string test_data_file = "test-dataset.txt";

//Run mode of this program: "train", "val" or "test"
static string RunModel = "train";

2. Generating the Model File

The pre-trained model files were built for the 1000-class ImageNet dataset. Since we are training on a custom dataset, the tail of the model (the classification head) needs a small modification.

 public static Function BuildTransferLearningModel(Function baseModel, string featureNodeName, string outputNodeName, string hiddenNodeName, int[] imageDims, int numClasses, DeviceDescriptor device)
 {
     //New input variable; subtract the channel mean (114) as in the original ImageNet training
     var input = Variable.InputVariable(imageDims, DataType.Float);
     var normalizedFeatureNode = CNTKLib.Minus(input, Constant.Scalar(DataType.Float, 114.0F));

     //Locate the original feature input and the last hidden node of the base model
     var oldFeatureNode = baseModel.Arguments.Single(a => a.Name == featureNodeName);
     var lastNode = baseModel.FindByName(hiddenNodeName);

     //Clone everything up to the last hidden node with frozen parameters,
     //remapping the old feature input to the new normalized input
     var clonedLayer = CNTKLib.AsComposite(lastNode).Clone(
         ParameterCloningMethod.Freeze,
         new Dictionary<Variable, Variable>()
         {
             { oldFeatureNode, normalizedFeatureNode }
         });

     //Append a new, trainable fully connected layer with numClasses outputs
     var clonedModel = Dense(clonedLayer, numClasses, device, Activation.None, outputNodeName);
     return clonedModel;
 }
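BuildTransferLearningModel calls a Dense helper that is not shown in this excerpt. A minimal sketch of such a helper, following the pattern of CNTK's own C# examples, could look like the following (the Activation enum and the parameter names are my assumptions, not necessarily what the original project uses):

 public enum Activation { None, ReLU, Sigmoid, Tanh }

 public static Function Dense(Variable input, int outputDim, DeviceDescriptor device, Activation activation = Activation.None, string outputName = "")
 {
     //Flatten the input to a vector if it still has spatial dimensions
     if (input.Shape.Rank != 1)
     {
         int newDim = input.Shape.Dimensions.Aggregate((d1, d2) => d1 * d2);
         input = CNTKLib.Reshape(input, new int[] { newDim });
     }

     //Fully connected layer: W * x + b, with Glorot-uniform initialized weights
     int inputDim = input.Shape[0];
     int[] weightShape = { outputDim, inputDim };
     var weights = new Parameter((NDShape)weightShape, DataType.Float,
         CNTKLib.GlorotUniformInitializer(
             CNTKLib.DefaultParamInitScale,
             CNTKLib.SentinelValueForInferParamInitRank,
             CNTKLib.SentinelValueForInferParamInitRank, 1),
         device, "denseWeights");
     int[] biasShape = { outputDim };
     var bias = new Parameter(biasShape, 0.0f, device, "denseBias");
     var fullyConnected = CNTKLib.Plus(bias, CNTKLib.Times(weights, input), outputName);

     //Optional non-linearity; the transfer-learning head above uses Activation.None
     switch (activation)
     {
         case Activation.ReLU: return CNTKLib.ReLU(fullyConnected);
         case Activation.Sigmoid: return CNTKLib.Sigmoid(fullyConnected);
         case Activation.Tanh: return CNTKLib.Tanh(fullyConnected);
         default: return fullyConnected;
     }
 }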
if (reCreateModel || File.Exists(Path.Combine(model_path, model_file)) == false)
{
     CreateAndSaveModel(config, Path.Combine(BaseWorkPath, base_model_file[model_type]), Path.Combine(model_path, model_file), device);
}
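CreateAndSaveModel is likewise not shown above. Under the assumption that it simply loads the selected base model, rebuilds the head with BuildTransferLearningModel, and saves the result, a sketch could look like this. The feature/hidden/output node names below are the ones used in CNTK's TransferLearning example for ResNet_18 and will almost certainly need to be adjusted for each base model in base_model_file:

static void CreateAndSaveModel(ClassificationConfig config, string baseModelFile, string targetModelFile, DeviceDescriptor device)
{
    //Load the pre-trained ImageNet base model
    Function baseModel = Function.Load(baseModelFile, device);

    //Node names are assumptions (ResNet_18 values); other base models expose different names.
    //config could also supply the class count instead of classes_names.
    Function model = BuildTransferLearningModel(
        baseModel,
        featureNodeName: "features",
        outputNodeName: "prediction",
        hiddenNodeName: "z.x",
        imageDims: new int[] { IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_DEPTH },
        numClasses: classes_names.Length,
        device: device);

    model.Save(targetModelFile);
}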

3. Generating the Training Dataset Files

if (reCreateData || Directory.Exists(model_path) == false)
{
      DirectoryInfo dir = new DirectoryInfo(model_path);
      if (dir.Exists)
      {
          DirectoryInfo[] childs = dir.GetDirectories();
          foreach (DirectoryInfo child in childs)
          {
              child.Delete(true);
          }
          dir.Delete(true);
      }
      Directory.CreateDirectory(model_path);
      CreateAndSaveDatasets(config, ImageDir_Train, Path.Combine(model_path, train_data_file), ImageDir_Val, Path.Combine(model_path, test_data_file), ext);
}
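CreateAndSaveDatasets is another helper whose body is not included here. Assuming the dataset files are plain map files for CNTK's ImageDeserializer (one "imagePath<TAB>classIndex" line per image, with classes taken from the sub-folder names), a minimal sketch could be:

static void CreateAndSaveDatasets(ClassificationConfig config, string trainDir, string trainFile, string valDir, string valFile, string ext)
{
    //config is kept to match the call above; class indices come from classes_names here
    WriteMapFile(trainDir, trainFile, ext);
    WriteMapFile(valDir, valFile, ext);
}

static void WriteMapFile(string imageDir, string mapFile, string ext)
{
    using (var writer = new StreamWriter(mapFile))
    {
        //Each class lives in its own sub-folder; the folder name is the class name
        foreach (var classDir in new DirectoryInfo(imageDir).GetDirectories())
        {
            int classIndex = Array.IndexOf(classes_names, classDir.Name);
            foreach (var image in classDir.GetFiles($"*.{ext}"))
            {
                //ImageDeserializer map-file format: <image path><TAB><label index>
                writer.WriteLine($"{image.FullName}\t{classIndex}");
            }
        }
    }
}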

[Screenshot: generated dataset files]


4. Full Training Code

if (RunModel == "train")
{
    classes_names = new DirectoryInfo(ImageDir_Train).GetDirectories().Select(d => d.Name).ToList().ToArray();
    //classes_names = Directory.GetDirectories(ImageDir_Train).Select(d => d.Substring(d.LastIndexOf('\\') + 1)).ToList().ToArray();

    var config = new ClassificationConfig(classes_names);
    if (reCreateData || Directory.Exists(model_path) == false)
    {
        DirectoryInfo dir = new DirectoryInfo(model_path);
        if (dir.Exists)
        {
            DirectoryInfo[] childs = dir.GetDirectories();
            foreach (DirectoryInfo child in childs)
            {
                child.Delete(true);
            }
            dir.Delete(true);
        }
        Directory.CreateDirectory(model_path);
        CreateAndSaveDatasets(config, ImageDir_Train, Path.Combine(model_path, train_data_file), ImageDir_Val, Path.Combine(model_path, test_data_file), ext);
        
    }
    if (reCreateModel || File.Exists(Path.Combine(model_path, model_file)) == false)
    {
        CreateAndSaveModel(config, Path.Combine(BaseWorkPath, base_model_file[model_type]), Path.Combine(model_path, model_file), device);
    }
    config.Save(Path.Combine(model_path, config_file));

    // Training
    _modelWrapper = new CntkModelWrapper(Path.Combine(model_path, model_file), device);
    var dataSource = CreateDataSource(Path.Combine(model_path, train_data_file));
    var trainer = CreateTrainer();

    var minibatchesSeen = 0;
    int data_length = readFileLines(Path.Combine(model_path, train_data_file));
    while (true)
    {
        var minibatchData = dataSource.MinibatchSource.GetNextMinibatch(batch_size, device);
        var arguments = new Dictionary<Variable, MinibatchData>
        {
            { _modelWrapper.Input, minibatchData[dataSource.FeatureStreamInfo] },
            { _modelWrapper.TrainingOutput, minibatchData[dataSource.LabelStreamInfo] }
        };
        trainer.TrainMinibatch(arguments, device);
        double loss = trainer.PreviousMinibatchLossAverage();
        double eval = trainer.PreviousMinibatchEvaluationAverage();

        int epoch = Convert.ToInt32((minibatchesSeen * batch_size / data_length)) + 1;
        Console.WriteLine($"[{epoch}:{TrainNum}/{minibatchesSeen % (data_length / batch_size)+1}] CrossEntropyLoss = {loss}, EvaluationCriterion = {eval}");

        minibatchesSeen++;

        if ((Convert.ToInt32((minibatchesSeen * batch_size / data_length)) + 1) > TrainNum)
        {
            _modelWrapper.Model.Save(Path.Combine(model_path, model_file));
            break;
        }
    }
    RunModel = "val";
}
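The training loop relies on three more helpers that are not shown: readFileLines, CreateDataSource and CreateTrainer. The sketches below are my assumptions about what they might look like, based on CNTK's C# examples: the data source wraps a MinibatchSource built from the map file via ImageDeserializer, and the trainer combines a softmax cross-entropy loss with a classification-error metric and a plain SGD learner. The DataSourceWrapper class name, and the assumption that _modelWrapper.TrainingOutput is the one-hot label variable, are mine:

//Counts the samples in a dataset map file (one image per line)
static int readFileLines(string filePath)
{
    return File.ReadAllLines(filePath).Length;
}

//Hypothetical wrapper matching the dataSource.* members used in the loop above
class DataSourceWrapper
{
    public MinibatchSource MinibatchSource;
    public StreamInformation FeatureStreamInfo;
    public StreamInformation LabelStreamInfo;
}

static DataSourceWrapper CreateDataSource(string mapFilePath)
{
    //Scale every image to the model input size while reading
    var transforms = new List<CNTKDictionary>
    {
        CNTKLib.ReaderScale(IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_DEPTH)
    };

    var deserializer = CNTKLib.ImageDeserializer(mapFilePath,
        LABEL_STREAM_NAME, (uint)classes_names.Length,
        FEATURE_STREAM_NAME, transforms);

    //Repeat indefinitely; sweepEnd is still flagged at the end of every pass,
    //which is what the validation loop checks for
    var sourceConfig = new MinibatchSourceConfig(new List<CNTKDictionary> { deserializer })
    {
        MaxSweeps = MinibatchSource.InfinitelyRepeat
    };
    var source = CNTKLib.CreateCompositeMinibatchSource(sourceConfig);

    return new DataSourceWrapper
    {
        MinibatchSource = source,
        FeatureStreamInfo = source.StreamInfo(FEATURE_STREAM_NAME),
        LabelStreamInfo = source.StreamInfo(LABEL_STREAM_NAME)
    };
}

static Trainer CreateTrainer()
{
    //Assumes _modelWrapper.TrainingOutput is the one-hot label variable
    //that the training minibatch labels are bound to above
    var loss = CNTKLib.CrossEntropyWithSoftmax(_modelWrapper.Model, _modelWrapper.TrainingOutput);
    var evalError = CNTKLib.ClassificationError(_modelWrapper.Model, _modelWrapper.TrainingOutput);

    var lrSchedule = new TrainingParameterScheduleDouble(learning_rate, 1);
    var learner = Learner.SGDLearner(_modelWrapper.Model.Parameters(), lrSchedule);

    return Trainer.CreateTrainer(_modelWrapper.Model, loss, evalError, new List<Learner> { learner });
}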

5. Full Validation Code

if (RunModel == "val")
{
    //Validation
    _modelWrapper = new CntkModelWrapper(Path.Combine(model_path, model_file), device);
    var dataSource_test = CreateDataSource(Path.Combine(model_path, test_data_file));

    const int minibatchSize = 1;
    var currentMinibatch = 0;
    int Correct = 0;
    int Total = 0;
    while (true)
    {
        var minibatchData = dataSource_test.MinibatchSource.GetNextMinibatch(minibatchSize, device);
        var inputDataMap = new Dictionary<Variable, Value>() { { _modelWrapper.Input, minibatchData[dataSource_test.FeatureStreamInfo].data } };
        var outputDataMap = new Dictionary<Variable, Value>() { { _modelWrapper.EvaluationOutput, null } };

        _modelWrapper.Model.Evaluate(inputDataMap, outputDataMap, device);
        var outputVal = outputDataMap[_modelWrapper.EvaluationOutput];
        var actual = outputVal.GetDenseData<float>(_modelWrapper.EvaluationOutput);
        var labelBatch = minibatchData[dataSource_test.LabelStreamInfo].data;
        var expected = labelBatch.GetDenseData<float>(_modelWrapper.Model.Output);

        Func<IEnumerable<IList<float>>, IEnumerable<int>> maxSelector =
            (collection) => collection.Select(x => x.IndexOf(x.Max()));

        var actualLabels = maxSelector(actual);
        var expectedLabels = maxSelector(expected);

        Correct += actualLabels.Zip(expectedLabels, (a, b) => a.Equals(b) ? 1 : 0).Sum();
        Total += actualLabels.Count();
        double acc = (Convert.ToDouble(Correct) / Total);
        Console.WriteLine($"Correct = {Correct}, Total = {Total}, acc={acc}");
        currentMinibatch++;
        if (minibatchData.Values.Any(x => x.sweepEnd))
        {
            break;
        }
    }
}

6. Full Prediction Code

if (RunModel == "test")
{
     //Predict on the images in the test folder
     _modelWrapper = new CntkModelWrapper(Path.Combine(model_path, model_file), device);
     var config = ClassificationConfig.Load(Path.Combine(model_path, config_file));

     string[] all_image = Directory.GetFiles(ImageDir_Test, $"*.{ext}");
     foreach(string file in all_image)
     {
         var inputValue = new Value(new NDArrayView(new int[] { IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_DEPTH }, ImageHelper.Load(IMAGE_WIDTH, IMAGE_HEIGHT, file), device));
         var inputDataMap = new Dictionary<Variable, Value>() { { _modelWrapper.Input, inputValue } };
         var outputDataMap = new Dictionary<Variable, Value>() { { _modelWrapper.EvaluationOutput, null } };

         _modelWrapper.Model.Evaluate(inputDataMap, outputDataMap, device);
         var outputData = outputDataMap[_modelWrapper.EvaluationOutput].GetDenseData<float>(_modelWrapper.EvaluationOutput).First();

         var output = outputData.Select(x => (double)x).ToArray();
         var classIndex = Array.IndexOf(output, output.Max());
         var className = config.GetClassNameByIndex(classIndex);
         Console.WriteLine(file + " : " + className);
     }
 }
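The prediction code loads each image through an ImageHelper.Load helper that is not included above. A minimal sketch, under the assumption that it returns the resized pixels as a float[] laid out plane by plane (BGR, matching the {width, height, depth} input shape) and built on System.Drawing:

static class ImageHelper
{
    //Resize the image and return its pixels as a float[] of length width*height*3,
    //one colour plane after another (B, G, R) to match the W x H x C input shape.
    //The channel order and layout must match what ImageDeserializer produced during training.
    public static float[] Load(int width, int height, string imagePath)
    {
        using (var original = new System.Drawing.Bitmap(imagePath))
        using (var resized = new System.Drawing.Bitmap(original, new System.Drawing.Size(width, height)))
        {
            var data = new float[width * height * 3];
            int idx = 0;
            var channels = new Func<System.Drawing.Color, byte>[] { c => c.B, c => c.G, c => c.R };
            foreach (var channel in channels)
            {
                for (int y = 0; y < height; y++)
                {
                    for (int x = 0; x < width; x++)
                    {
                        data[idx++] = channel(resized.GetPixel(x, y));
                    }
                }
            }
            return data;
        }
    }
}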

Training Results

[Screenshots: training log output]

Prediction Results

[Screenshot: prediction output]

Summary

For the complete source code, please contact me via private message.
