【人工智能笔记】第十二节 Tensorflow 2.0 实现指针仪表方向纠正及指针识别(下)

相关资料:

【人工智能笔记】第十节 Tensorflow 2.0 实现指针仪表方向纠正及指针识别(上)

【人工智能笔记】第十一节 Tensorflow 2.0 实现指针仪表方向纠正及指针识别(中)

这一节,会介绍如何使用现实素材继续训练模型,来完成真正的仪表识别。首先我们会使用标注工具对素材进行标注,然后使用实际标注素材进行训练,编写相应的训练代码。下面,我会逐步介绍如何实现。

先上效果图,估计值做了标准换算,范围0-100:

一、使用标注工具进行素材标注

标注工具使用labelme,需使用我修改过的版本进行标注,源码:GitHub - tfwcn/labelme: Image Polygonal Annotation with Python (polygon, rectangle, line, point and image-level flag annotation).

1. 下载源码

git clone https://github.com/tfwcn/labelme.git

2. 安装环境

安装pillow,推荐使用adaconda

pip install pillow
pip install qtpy5
pip install labelme

3. 启动标注工具

cd labelme
python main.py

4. 标注素材

标注素材时,默认最小刻度为0,最大刻度为100。把刻度范围标准化方便训练。

 注:素材数量要超过2万效果才会好。

二、训练

训练脚本,在ai_api\gauge下新建train_label.py,内容如下:

import os
import sys
import numpy as np
import cv2
import tensorflow as tf
import random
import json
import argparse

sys.path.append(os.getcwd())
import ai_api.utils.image_helpler as image_helpler
import ai_api.utils.file_helper as file_helper
from ai_api.gauge.gauge_model import GaugeModel

# 把模型的变量分布在哪个GPU上给打印出来
# tf.debugging.set_log_device_placement(True)

# 启动参数
parser = argparse.ArgumentParser()
parser.add_argument('--file_path', default='image_data')
parser.add_argument('--batch_size', default=6, type=int)
args = parser.parse_args()

model = GaugeModel.GetStaticModel()


def generator(file_list, batch_size):
    X = []
    Y = []
    # 用于数据平均
    value_index = 0
    skip_index = 0
    while True:
        # 打乱列表顺序
        random_list = np.array(file_list)
        np.random.shuffle(random_list)
        for file_path in random_list:
            try:
                # 读取json文件
                with open(file_path, 'r', encoding='utf-8') as f:
                    json_data = json.load(f)
                # json文件目录
                json_dir = os.path.dirname(file_path)
                # json文件名
                json_name = os.path.basename(file_path)
                # 图片路径
                image_path = os.path.join(
                    json_dir, json_data['imagePath'].replace('\\', '/'))
                # 图片文件名
                image_name = os.path.basename(image_path)
                # 值
                value = json_data['shapes'][0]["label"].split('_')
                if len(value)>1:
                    value = float(value[1])
                else:
                    value = float(value[0])
                if skip_index > 50:
                    skip_index = 0
                    # 值下标加1
                    value_index += 1
                    if value_index > 10:
                        value_index = 0
                # 值平均
                if (value_index-1)*10 >= value or value > value_index*10:
                    skip_index += 1
                    continue
                # print('添加:', (value // 10), value_index)
                # print('添加:', value)
                skip_index = 0
                # 值下标加1
                value_index += 1
                if value_index > 10:
                    value_index = 0
                # 原始点列表
                json_points = np.float32(json_data['shapes'][0]['points'])
                # 点匹配
                point_center_x = (min(json_points[:, 0]) + max(json_points[:, 0])) / 2
                point_center_y = (min(json_points[:, 1]) + max(json_points[:, 1])) / 2
                for p in json_points:
                    if p[0] < point_center_x and p[1] < point_center_y:
                        pointLT = p
                    elif p[0] > point_center_x and p[1] < point_center_y:
                        pointRT = p
                    elif p[0] < point_center_x and p[1] > point_center_y:
                        pointLB = p
                    elif p[0] > point_center_x and p[1] > point_center_y:
                        pointRB = p
                points = np.float32([pointLT, pointLB, pointRT, pointRB])
                # print('points:', points)
                # 读取图片
                img = image_helpler.fileToOpencvImage(image_path)
                # 缩放图片
                img, points, _ = image_helpler.opencvProportionalResize(
                    img, (400, 400), points=points)
                # print('imgType:', type(img))
                # width, height = image_helpler.opencvGetImageSize(img)
                # print('imgSize:', width, height)
                # 获取随机变换图片及标签
                random_img, target_data = model.get_random_data(
                    img, value, target_points=points)
                X.append(random_img)
                # Y.append([value])
                Y.append(target_data)
                if len(Y) == batch_size:
                    result_x = np.array(X)
                    result_y = np.array(Y)
                    # print('generator', result_x.shape, result_y.shape)
                    yield result_x, result_y
                    X = []
                    Y = []
            except Exception as expression:
                print('异常:', expression, file_path)


def train():
    '''训练'''
    train_path = args.file_path
    file_list = file_helper.ReadFileList(train_path, r'.json$')
    print('图片数:', len(file_list))
    # 训练参数
    batch_size = args.batch_size
    steps_per_epoch = 200
    epochs = 500
    model.fit_generator(generator(file_list, batch_size),
                        steps_per_epoch, epochs, auto_save=True)


def main():
    train()


if __name__ == '__main__':
    main()

执行训练命令:

python .\ai_api\gauge\train_label.py --file_path 素材目录 --batch_size 8

三、识别

创建测试页面,在ai_api\static\gauge新建predict_image_read.html文件,内容如下:

<!DOCTYPE html>

<head>
  <meta charset="utf-8">
  <title>ECharts</title>
</head>

<body>
  <input id="selectFile" type="file" />
  <input id="btnSubmit" type="button" title="测试" value="测试"></input> 估计值:<span id="txtValue"></span>
  <br />
  <span>
    原图:
    <!-- 为ECharts准备一个具备大小(宽高)的Dom -->
    <img id="main" style="height:400px; width: 400px; vertical-align: top;"></img>
  </span>
  <br />
  <span>
    缩放后的测试图:
    <img id="random_img" style="height:400px; width: 400px; vertical-align: top;"></img>
  </span>
  <span>
    算法纠正后的图:
    <img id="perspective_img" style="height:400px; width: 400px; vertical-align: top;"></img>
  </span>
  <script src="https://unpkg.com/axios/dist/axios.min.js"></script>
  <script type="text/javascript">
    function SubmitFun() {
      document.getElementById('txtValue').innerText = '检测中。。。'
      var preview = document.querySelector('#main');
      var file = document.querySelector('#selectFile').files[0];
      var reader = new FileReader();

      reader.addEventListener("load", function () {
        preview.src = reader.result;
        let img = reader.result;
        // console.log('图片数据:', img);
        axios.post('/ai_api/gauge/gauge_predict', {
          img_data: img,
          read: 1,
        })
          .then(function (response) {
            console.log(response);
            // alert('识别值:'+response.data.value[0][0]);
            document.getElementById('random_img').src = 'data:image/jpg;base64,' + response.data.random_img;
            document.getElementById('perspective_img').src = 'data:image/jpg;base64,' + response.data.perspective_img;
            document.getElementById('txtValue').innerText = response.data.value[0][0] * 100;
          })
          .catch(function (error) {
            console.log(error);
          });
      }, false);

      if (file) {
        reader.readAsDataURL(file);
      }
    }
    document.getElementById('btnSubmit').addEventListener('click', SubmitFun);
  </script>
</body>

启动服务:

python .\manage.py runserver 0.0.0.0:8000

浏览器打开页面测试:http://127.0.0.1:8000/static/gauge/predict_image_read.html

源码:

【人工智能笔记】第十节Tensorflow2.0实现指针仪表方向纠正及指针识别源码_tensorflow2.0图像分类-深度学习文档类资源-CSDN下载 

解压密码:xoaEe2io3h324

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

PPHT-H

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值