一、新建工程并测试
Python中的TensorFlow用的人比较多,C#中用的人比较少,在这里简单介绍一下TensorFlowSharp的使用
Github项目地址
https://github.com/migueldeicaza/TensorFlowSharp
CSDN项目地址
https://gitee.com/mirrors/TensorFlowSharp
TensorFlowSharp封装了比较低级别的TensorFlow API,功能与其他语言相通。但是目前没有像Python绑定那样包含高级别的API,因此对于那些高级操作来说,使用它更麻烦。您可以使用Python中的TensorFlow或Keras进行原型,然后保存您的图表或经过训练的模型,然后使用TensorFlowSharp在.NET中加载结果,并为您自己的数据提供训练或运行。
新建.net 控制台工程,这里选择.net 4.7.1通过nuget安装TensorFlowSharp,测试代码如下
using System;
using TensorFlow;
namespace TensorFlowSimilar
{
class Program
{
static void Main(string[] args)
{
using (var session = new TFSession())
{
var graph = session.Graph;
Console.WriteLine($"TFCore.Version={TFCore.Version}");
TFOutput a = graph.Const(5);
TFOutput b = graph.Const(6);
Console.WriteLine($"a={session.GetRunner().Run(a)},b={session.GetRunner().Run(b)}");
// 两常量加
var addingResults = session.GetRunner().Run(graph.Add(a, b));
var addingResultValue = addingResults.GetValue();
Console.WriteLine($"a+b={addingResultValue}");
// 两常量乘
var multiplyResults = session.GetRunner().Run(graph.Mul(a, b));
var multiplyResultValue = multiplyResults.GetValue();
Console.WriteLine($"a*b={multiplyResultValue}");
}
Console.ReadKey();
}
}
}
测试结果,哪个Log提示说是不支持AVX2,可以占时忽略他
二、 图片内容识别
下面介绍一下TensorFlowSharp的高级应用,图片内容识别,Github工程地址
https://github.com/migueldeicaza/TensorFlowSharp/tree/master/Examples/ExampleInceptionInference
其中使用了训练好的模型,如果自己训练更多的模型,匹配效果会更好
看下结果,老虎竟然是最佳匹配,长颈鹿竟然没识别到,应该是训练模型中没有
源代码
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Net;
using TensorFlow;
namespace TensorFlowSimilar
{
class ImageRecognition
{
static string dir, modelFile, labelsFile;
public static void Main(string[] args)
{
dir = "temp";
List<string> files = Directory.GetFiles("image").ToList();
ModelFiles(dir);
var graph = new TFGraph();
var model = File.ReadAllBytes(modelFile);
graph.Import(model, "");
using (var session = new TFSession(graph))
{
var labels = File.ReadAllLines(labelsFile);
foreach (var file in files)
{
//对图像文件运行推断,对于多个镜像,session.Run()可以在循环调用,图像可以批处理,因为模型接收批量图像数据作为输入
var tensor = CreateTensorFromImageFile(file);
var runner = session.GetRunner();
runner.AddInput(graph["input"][0], tensor).Fetch(graph["output"][0]);
var output = runner.Run();
var result = output[0];
var rshape = result.Shape;
if (result.NumDims != 2 || rshape[0] != 1)
{
var shape = "";
foreach (var d in rshape)
{
shape += $"{d} ";
}
shape = shape.Trim();
Console.WriteLine($"Error: expected to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape [{shape}]");
Environment.Exit(1);
}
bool jagged = true;
var bestIdx = 0;
float best = 0;
if (jagged)
{
var probabilities = ((float[][])result.GetValue(jagged: true))[0];
for (int i = 0; i < probabilities.Length; i++)
{
if (probabilities[i] > best)
{
bestIdx = i;
best = probabilities[i];
}
}
}
else
{
var val = (float[,])result.GetValue(jagged: false);
for (int i = 0; i < val.GetLength(1); i++)
{
if (val[0, i] > best)
{
bestIdx = i;
best = val[0, i];
}
}
}
Console.WriteLine($"文件名:{Path.GetFileName(file)} 最佳匹配:[{bestIdx}] {best * 100.0}% 匹配对象:{labels[bestIdx]}");
}
}
Console.ReadKey();
}
static TFTensor CreateTensorFromImageFile(string file)
{
var contents = File.ReadAllBytes(file);
var tensor = TFTensor.CreateString(contents);
ConstructGraphToNormalizeImage(out TFGraph graph, out TFOutput input, out TFOutput output);
using (var session = new TFSession(graph))
{
var normalized = session.Run(
inputs: new[] { input },
inputValues: new[] { tensor },
outputs: new[] { output });
return normalized[0];
}
}
static void ConstructGraphToNormalizeImage(out TFGraph graph, out TFOutput input, out TFOutput output)
{
const int W = 224;
const int H = 224;
const float Mean = 117;
const float Scale = 1;
graph = new TFGraph();
input = graph.Placeholder(TFDataType.String);
output = graph.Div(
x: graph.Sub(
x: graph.ResizeBilinear(
images: graph.ExpandDims(
input: graph.Cast(
graph.DecodeJpeg(contents: input, channels: 3), DstT: TFDataType.Float),
dim: graph.Const(0, "make_batch")),
size: graph.Const(new int[] { W, H }, "size")),
y: graph.Const(Mean, "mean")),
y: graph.Const(Scale, "scale"));
}
//下载模型库
static void ModelFiles(string dir)
{
string url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
modelFile = Path.Combine(dir, "tensorflow_inception_graph.pb");
labelsFile = Path.Combine(dir, "imagenet_comp_graph_label_strings.txt");
var zipfile = Path.Combine(dir, "inception5h.zip");
if (File.Exists(modelFile) && File.Exists(labelsFile))
{
return;
}
Directory.CreateDirectory(dir);
var wc = new WebClient();
wc.DownloadFile(url, zipfile);
ZipFile.ExtractToDirectory(zipfile, dir);
File.Delete(zipfile);
}
}
}
识别用的图片,放到Debug目录下的image文件夹中
精简工程下载地址