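Related material:
[AI Notes] Part 10: Pointer Gauge Orientation Correction and Pointer Recognition with Tensorflow 2.0 (Part 1)
[AI Notes] Part 11: Pointer Gauge Orientation Correction and Pointer Recognition with Tensorflow 2.0 (Part 2)
This part shows how to continue training the model on real-world material so that it can read actual gauges. We first annotate the material with a labeling tool, then write a training script that trains on the annotated data. The steps are described one by one below.
Here is the result first. The estimated value has been normalized to the range 0-100: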
I. Annotating the material with a labeling tool
The labeling tool is labelme. Annotation requires my modified version; source: GitHub - tfwcn/labelme: Image Polygonal Annotation with Python (polygon, rectangle, line, point and image-level flag annotation).
1. Download the source
git clone https://github.com/tfwcn/labelme.git
2. Set up the environment
Install pillow; using Anaconda is recommended
pip install pillow
pip install pyqt5
pip install labelme
3. Start the labeling tool
cd labelme
python main.py
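4. Annotate the material
When annotating, the minimum scale value defaults to 0 and the maximum to 100. Normalizing the scale range this way makes training easier.
Note: results are only good once the number of samples exceeds 20,000.
II. Training
For the training script, create train_label.py under ai_api\gauge with the following content: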
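To make the normalization concrete, here is a minimal sketch of how a reading on a gauge with an arbitrary scale could be mapped onto the 0-100 range before it is entered as a label. The helper name normalize_reading and the thermometer numbers are assumptions for illustration; they are not part of labelme or of the project code.

```python
def normalize_reading(raw, scale_min, scale_max):
    """Map a raw gauge reading onto the 0-100 range used for labels.

    Hypothetical helper: scale_min and scale_max are the smallest and
    largest values printed on the gauge face.
    """
    if scale_max == scale_min:
        raise ValueError("scale_min and scale_max must differ")
    return (raw - scale_min) / (scale_max - scale_min) * 100.0


# A thermometer graduated from -20 to 80 that currently shows 30
print(normalize_reading(30, -20, 80))  # prints 50.0
```

With this convention every gauge, whatever its physical unit, is labeled on the same 0-100 scale, which is what lets a single model be trained on mixed material.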
import os
import sys
import numpy as np
import cv2
import tensorflow as tf
import random
import json
import argparse
sys.path.append(os.getcwd())
import ai_api.utils.image_helpler as image_helpler
import ai_api.utils.file_helper as file_helper
from ai_api.gauge.gauge_model import GaugeModel
# Print which GPU each model variable is placed on
# tf.debugging.set_log_device_placement(True)
# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--file_path', default='image_data')
parser.add_argument('--batch_size', default=6, type=int)
args = parser.parse_args()
model = GaugeModel.GetStaticModel()
def generator(file_list, batch_size):
    X = []
    Y = []
    # Used to balance the samples across value ranges (buckets of 10)
    value_index = 0
    skip_index = 0
    while True:
        # Shuffle the file list
        random_list = np.array(file_list)
        np.random.shuffle(random_list)
        for file_path in random_list:
            try:
                # Read the JSON file
                with open(file_path, 'r', encoding='utf-8') as f:
                    json_data = json.load(f)
                # Directory of the JSON file
                json_dir = os.path.dirname(file_path)
                # Name of the JSON file
                json_name = os.path.basename(file_path)
                # Image path
                image_path = os.path.join(
                    json_dir, json_data['imagePath'].replace('\\', '/'))
                # Image file name
                image_name = os.path.basename(image_path)
                # Labeled value
                value = json_data['shapes'][0]["label"].split('_')
                if len(value) > 1:
                    value = float(value[1])
                else:
                    value = float(value[0])
                if skip_index > 50:
                    skip_index = 0
                    # Too many skips in a row: move on to the next bucket
                    value_index += 1
                    if value_index > 10:
                        value_index = 0
                # Balance the values: skip samples outside the current bucket
                if (value_index-1)*10 >= value or value > value_index*10:
                    skip_index += 1
                    continue
                # print('Added:', (value // 10), value_index)
                # print('Added:', value)
                skip_index = 0
                # Sample accepted: move on to the next bucket
                value_index += 1
                if value_index > 10:
                    value_index = 0
                # Original point list
                json_points = np.float32(json_data['shapes'][0]['points'])
                # Match each point to a corner relative to the center
                point_center_x = (min(json_points[:, 0]) + max(json_points[:, 0])) / 2
                point_center_y = (min(json_points[:, 1]) + max(json_points[:, 1])) / 2
                for p in json_points:
                    if p[0] < point_center_x and p[1] < point_center_y:
                        pointLT = p
                    elif p[0] > point_center_x and p[1] < point_center_y:
                        pointRT = p
                    elif p[0] < point_center_x and p[1] > point_center_y:
                        pointLB = p
                    elif p[0] > point_center_x and p[1] > point_center_y:
                        pointRB = p
                points = np.float32([pointLT, pointLB, pointRT, pointRB])
                # print('points:', points)
                # Read the image
                img = image_helpler.fileToOpencvImage(image_path)
                # Resize the image
                img, points, _ = image_helpler.opencvProportionalResize(
                    img, (400, 400), points=points)
                # print('imgType:', type(img))
                # width, height = image_helpler.opencvGetImageSize(img)
                # print('imgSize:', width, height)
                # Get a randomly transformed image and its target data
                random_img, target_data = model.get_random_data(
                    img, value, target_points=points)
                X.append(random_img)
                # Y.append([value])
                Y.append(target_data)
                if len(Y) == batch_size:
                    result_x = np.array(X)
                    result_y = np.array(Y)
                    # print('generator', result_x.shape, result_y.shape)
                    yield result_x, result_y
                    X = []
                    Y = []
            except Exception as expression:
                print('Error:', expression, file_path)


def train():
    '''Train the model.'''
    train_path = args.file_path
    file_list = file_helper.ReadFileList(train_path, r'.json$')
    print('Number of images:', len(file_list))
    # Training parameters
    batch_size = args.batch_size
    steps_per_epoch = 200
    epochs = 500
    model.fit_generator(generator(file_list, batch_size),
                        steps_per_epoch, epochs, auto_save=True)


def main():
    train()


if __name__ == '__main__':
    main()
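Three steps inside the generator are easy to misread, so here they are pulled out as small standalone functions: reading the value from the label, testing whether a value belongs to the current balancing bucket, and sorting the four annotated points into corner order. This is only a sketch in plain Python. The function names and the sample label 'gauge_37.5' are assumptions for illustration, and unlike the script, order_corners raises an error when the four points do not fall into four distinct quadrants.

```python
def parse_label_value(label):
    """Return the numeric reading stored in a shape label.

    Mirrors the script: with an underscore the second field is used,
    otherwise the whole label is the value.
    """
    parts = label.split('_')
    return float(parts[1] if len(parts) > 1 else parts[0])


def in_bucket(value, bucket_index):
    """True when (bucket_index - 1) * 10 < value <= bucket_index * 10."""
    return (bucket_index - 1) * 10 < value <= bucket_index * 10


def order_corners(points):
    """Sort four (x, y) points as [left-top, left-bottom, right-top, right-bottom].

    Image coordinates are assumed, so a smaller y means higher up.
    """
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    center_x = (min(xs) + max(xs)) / 2
    center_y = (min(ys) + max(ys)) / 2
    corners = {}
    for x, y in points:
        if x == center_x or y == center_y:
            raise ValueError('point lies on a center line: %r' % ((x, y),))
        key = ('L' if x < center_x else 'R') + ('T' if y < center_y else 'B')
        corners[key] = (x, y)
    if len(corners) != 4:
        raise ValueError('points do not cover four distinct corners')
    return [corners['LT'], corners['LB'], corners['RT'], corners['RB']]


print(parse_label_value('gauge_37.5'))
print(in_bucket(10, 1), in_bucket(10, 2))
print(order_corners([(300, 20), (10, 15), (290, 310), (5, 300)]))
```

The bucket test shows why the generator cycles value_index from 0 to 10: bucket 0 only accepts the value 0, and buckets 1 to 10 each cover one tenth of the 0-100 range, so consecutive samples come from different ranges.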
Run the training command:
python .\ai_api\gauge\train_label.py --file_path <material_dir> --batch_size 8
III. Recognition
Create a test page: add a new file predict_image_read.html under ai_api\static\gauge with the following content:
<!DOCTYPE html>
<head>
  <meta charset="utf-8">
  <title>Gauge reading test</title>
</head>
<body>
  <input id="selectFile" type="file" />
  <input id="btnSubmit" type="button" title="Test" value="Test" /> Estimated value: <span id="txtValue"></span>
  <br />
  <span>
    Original image:
    <!-- Preview of the selected image file -->
    <img id="main" style="height:400px; width: 400px; vertical-align: top;" />
  </span>
  <br />
  <span>
    Resized test image:
    <img id="random_img" style="height:400px; width: 400px; vertical-align: top;" />
  </span>
  <span>
    Image after correction by the algorithm:
    <img id="perspective_img" style="height:400px; width: 400px; vertical-align: top;" />
  </span>
  <script src="https://unpkg.com/axios/dist/axios.min.js"></script>
  <script type="text/javascript">
    function SubmitFun() {
      document.getElementById('txtValue').innerText = 'Detecting...';
      var preview = document.querySelector('#main');
      var file = document.querySelector('#selectFile').files[0];
      var reader = new FileReader();
      reader.addEventListener("load", function () {
        preview.src = reader.result;
        let img = reader.result;
        // console.log('Image data:', img);
        axios.post('/ai_api/gauge/gauge_predict', {
          img_data: img,
          read: 1,
        })
          .then(function (response) {
            console.log(response);
            // alert('Recognized value: ' + response.data.value[0][0]);
            document.getElementById('random_img').src = 'data:image/jpg;base64,' + response.data.random_img;
            document.getElementById('perspective_img').src = 'data:image/jpg;base64,' + response.data.perspective_img;
            document.getElementById('txtValue').innerText = response.data.value[0][0] * 100;
          })
          .catch(function (error) {
            console.log(error);
            document.getElementById('txtValue').innerText = 'Request failed';
          });
      }, false);
      if (file) {
        reader.readAsDataURL(file);
      } else {
        // No file selected: do not leave the status text stuck on "Detecting..."
        document.getElementById('txtValue').innerText = 'Please select an image first';
      }
    }
    document.getElementById('btnSubmit').addEventListener('click', SubmitFun);
  </script>
</body>
Start the server:
python .\manage.py runserver 0.0.0.0:8000
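Open the page in a browser to test: http://127.0.0.1:8000/static/gauge/predict_image_read.html
Source code:
[AI Notes] Part 10: source code for pointer gauge orientation correction and pointer recognition with Tensorflow 2.0 (CSDN download)
Archive password: xoaEe2io3h324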
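The same request can also be sent from a script instead of the browser. The sketch below uses only the Python standard library and copies what the test page does: the image is sent as a data URL in img_data together with read set to 1, and the reading is taken from value[0][0] in the response. The function names are hypothetical, and I am assuming that the endpoint accepts a JSON body without a CSRF token, as it does for the page above, and that the returned value lies in 0-1, since the page multiplies it by 100.

```python
import base64
import json
import urllib.request


def build_payload(image_bytes, mime='image/jpeg'):
    """Build the JSON body the test page sends: a data URL plus read=1."""
    encoded = base64.b64encode(image_bytes).decode('ascii')
    return {'img_data': 'data:%s;base64,%s' % (mime, encoded), 'read': 1}


def to_gauge_units(normalized, scale_min, scale_max):
    """Convert a 0-1 prediction back to the gauge's own scale (hypothetical helper)."""
    return scale_min + normalized * (scale_max - scale_min)


def predict(image_path,
            url='http://127.0.0.1:8000/ai_api/gauge/gauge_predict'):
    """POST one image to the running server and return value[0][0]."""
    with open(image_path, 'rb') as f:
        payload = build_payload(f.read())
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode('utf-8'),
        headers={'Content-Type': 'application/json'})
    with urllib.request.urlopen(request) as response:
        body = json.loads(response.read().decode('utf-8'))
    return body['value'][0][0]


# Example, with the server from the previous step running:
# value = predict('gauge.jpg')
# print(value * 100)                      # same number the page shows
# print(to_gauge_units(value, -20, 80))   # reading on a -20..80 scale
```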