如何将训练好的Python模型给JavaScript使用?

前言

  从前面的Tensorflow环境搭建到目标检测模型迁移学习,已经完成了一个简答的扑克牌检测器,不管是从图片还是视频都能从画面中识别出有扑克的目标,并标识出扑克点数。但是,我想在想让他放在浏览器上可能实际使用,那么要如何让Tensorflow模型转换成web格式的呢?接下来将从实践的角度详细介绍一下部署方法!

环境

  • Windows10
  • Anaconda3
  • TensorFlow.js converter

converter介绍

  Converter全名是TensorFlow.js Converter,他可以将TensorFlow GraphDef模型(通过Python API创建的,可以先理解为Python模型) 转换成Tensorflow.js可读取的模型格式(json格式), 用于在浏览器上对指定数据进行推算。

 

converter安装

  为了不影响前面目标检测训练环境,这里我用conda创建了一个新的Python虚拟环境,Python版本3.6.8。在安装转换器的时候,如果当前环境没有Tensorflow,默认会安装与TF相关的依赖,只需要进入指定虚拟环境,输入以下命令。

pip install tensorflowjs

converter用法

tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve ./saved_model ./web_model

1. 产生的文件(生成的web格式模型)

转换器命令执行后生产两种文件,分别是model.json (数据流图和权重清单)和group1-shard\*of\* (二进制权重文件)

2. 输入的必要条件(命令参数和选项[带--为选项])

converter转换指令后面主要携带四个参数,分别是输入模型的格式,输出模型的格式,输入模型的路径,输出模型的路径,更多帮助信息可以通过以下命令查看,另附命令分解图。

tensorflowjs_converter --help

2.1. --input_format

要转换的模型的格式,SavedModel 为 tf_saved_model, frozen model 为 tf_frozen_model, session bundle 为 tf_session_bundle, TensorFlow Hub module 为 tf_hub,Keras HDF5 为 keras。

2.2. --output_format

输出模型的格式, 分别有tfjs_graph_model (tensorflow.js图模型,保存后的web模型没有了再训练能力,适合SavedModel输入格式转换),tfjs_layers_model(tensorflow.js层模型,具有有限的Keras功能,不适合TensorFlow SavedModels转换)。

2.3. input_path

saved model, session bundle 或 frozen model的完整的路径,或TensorFlow Hub模块的路径。

2.4. output_path

输出文件的保存路径。

2.5. --saved_model_tags

只对SavedModel转换用的选项:输入需要加载的MetaGraphDef相对应的tag,多个tag请用逗号分隔。默认为 serve

2.6. --signature_name

对TensorFlow Hub module和SavedModel转换用的选项:对应要加载的签名,默认为default

2.7. --output_node_names

输出节点的名字,每个名字用逗号分离。

3. 常用的两组命令行

1. covert from saved_model

tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve ./saved_model ./web_model

2. convert from frozen_model
tensorflowjs_converter --input_format=tf_frozen_model --output_node_names='num_detections,detection_boxes,detection_scores,detection_classes' ./frozen_inference_graph.pb  ./web_modelk

开始实践

1. 找到通过export_inference_graph.py导出的模型

导出的模型在项目的inference_graph文件夹(models\research\object_detection)里,frozen_inference_graph.pb是 tf_frozen_model输入格式需要的,而saved_model文件夹就是tf_saved_model格式。在当前目录下新建web_model目录,用于存储转换后的web格式的模型。

2. 开始转换

在当前虚拟环境下,进入到inference_graph目录下,输入以下命令,之后就会在web_model生成一个json文件和多个权重文件。

tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve ./saved_model ./web_model

3. 浏览器端部署

3.1. 创建一个前端项目,将web_model放入其中。

3.2.编写代码

<!doctype html>
<head>
  <link rel="stylesheet" href="tfjs-examples.css" />
  <style>
  canvas {outline: orange 2px solid; margin: 10px 0;}
  </style>
</head>

<body>
  <div class="tfjs-example-container centered-container">
    <section class='title-area'>
      <h1>赌圣2023</h1>
    </section>
    <p class='section-head'>模型描述</p>
    <p>我看你怎么出老千!</p>
    <p class='section-head'>模型状态</p>
    <div id="status">加载模型中...</div>
    <div>
      <p class='section-head'>效果展示</p>
      <p></button><input type="file" accept="image/*" id="test"/></p>
      <canvas id="data-canvas" width="300" height="1100"></canvas>
    </div>
  </div>

</body>

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>

<script>
  const canvas = document.getElementById('data-canvas');
  const status = document.getElementById('status');
  const testModel = document.getElementById('test');

  const BOUNDING_BOX_LINE_WIDTH = 3;
  const BOUNDING_BOX_STYLE1 = 'rgb(0,0,255)';
  const BOUNDING_BOX_STYLE2 = 'rgb(0,255,0)';

  async function init() {

    const LOCAL_MODEL_PATH = './web_model/model.json';

    // 将本地模型保存到浏览器
    // tf.sequential().save

    // 加载本地模型
    let model;
    try {
      model = await tf.loadGraphModel(LOCAL_MODEL_PATH);
      testModel.disabled = false;
      status.textContent = '成功加载本地模型!请亮出你的牌吧';
      
      // 默认扑克牌
      runAndVisualizeInference('./cam_image39.jpg', model)
      
    } catch (err) {
      console.log('加载本地模型错误:', err);
      status.textContent = '加载本地模型失败';
    }

    testModel.addEventListener('change', (e) => {
      runAndVisualizeInference(e, model)
    });
}

async function runAndVisualizeInference(e, model) {

  if (typeof e === 'string') {
    await new Promise((resolve, reject) => {
      // 图片显示在canvas中
      var img = new Image;
      img.src = e;
      img.onload = function () { // 必须onload之后再画
        let w = 500;
        let h = img.height/img.width*500;
        canvas.width = w;
        canvas.height = h;
        var ctx = canvas.getContext('2d');
        ctx.drawImage(img,0,0,w,h);
        resolve();
      }
    })
  } else {

    // 上传图片并显示在canvas中
    var file = e.target.files[0]; 
    if (!/image\/\w+/.test(file.type)) {
      alert("请确保文件为图像类型");
      return false;
    }
    var reader = new FileReader();
    reader.readAsDataURL(file); // 转化成base64数据类型
    await new Promise((resolve, reject) => {
      reader.onload = function (e) {
        // 图片显示在canvas中
        var img = new Image;
        img.src = this.result;
        img.onload = function () { // 必须onload之后再画
          let w = 500;
          let h = img.height/img.width*500;
          canvas.width = w;
          canvas.height = h;
          var ctx = canvas.getContext('2d');
          ctx.drawImage(img,0,0,w,h);
          resolve();
        }
      }
    })
  }

  // 模型输入处理
  let image = tf.browser.fromPixels(canvas);
  const t4d = image.expandDims(0);

  const outputDim = [
    'num_detections', 'detection_boxes', 'detection_scores',
    'detection_classes'
  ];
  
  const labelMap = {
    1: '九点',
    2: '十点',
    3: 'Jack',
    4: 'Queen',
    5: 'King',
    6: 'Ace'
  }
  
  let modelOut = {}, boxes = [], w = canvas.width, h = canvas.height;
  console.log(model)
  
  for (const dim of outputDim) {
    let tensor = await model.executeAsync({
      'image_tensor': t4d
    }, `${dim}:0`);
    modelOut[dim] = await tensor.data();
  }
  console.log(modelOut)
  
  for (let i=0; i<modelOut['detection_scores'].length; i++) {
    const score = modelOut['detection_scores'][i];
  
    if (score < 0.5) break; // 置信度过滤
  
    boxes.push({
      ymin: modelOut['detection_boxes'][i*4]*h,
      xmin: modelOut['detection_boxes'][i*4+1]*w,
      ymax: modelOut['detection_boxes'][i*4+2]*h,
      xmax: modelOut['detection_boxes'][i*4+3]*w,
      label: labelMap[modelOut['detection_classes'][i]],
    })
  }
  
  console.log(boxes)

  // 可视化检测框
  drawBoundingBoxes(canvas, boxes);

  // 张量运行内存清除
  tf.dispose([image, modelOut]);
}

function drawBoundingBoxes(canvas, predictBoundingBoxArr) {
  for (const box of predictBoundingBoxArr) {
    let left = box.xmin;
    let right = box.xmax;
    let top = box.ymin;
    let bottom = box.ymax;

    const ctx = canvas.getContext('2d');
    ctx.beginPath();
    ctx.strokeStyle = box.label==='ZERO_DEV'?BOUNDING_BOX_STYLE1:BOUNDING_BOX_STYLE2;
    ctx.lineWidth = BOUNDING_BOX_LINE_WIDTH;
    ctx.moveTo(left, top);
    ctx.lineTo(right, top);
    ctx.lineTo(right, bottom);
    ctx.lineTo(left, bottom);
    ctx.lineTo(left, top);
    ctx.stroke();

    ctx.font = '24px Arial bold';
    ctx.fillStyle = box.label==='zfc'?BOUNDING_BOX_STYLE2:BOUNDING_BOX_STYLE1;
    ctx.fillText(box.label, left+8, top+8);
  }
}

init();

</script>

3.3. 运行结果

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
训练好的YOLO(You Only Look Once)模型部署在网页上通常需要以下步骤: 1. 准备环境:在网页端部署YOLO模型,需要一台具备服务器功能的主机。可以选择自己搭建服务器,或者使用云服务器服务提供商提供的服务。 2. 安装相应软件:在服务器上安装必要的软件包和库,如PythonTensorFlow、OpenCV等。这些软件包可以用于模型的加载、图像处理等任务。 3. 导入模型:将训练好的YOLO模型文件(一般包括.cfg、.weights、 .names文件)导入到服务器。这些文件描述了模型的结构、权重和标签等信息。 4. 编写服务端代码:使用Python等编程语言,编写服务器端代码。这个代码需要负责接收网页端的请求,并调用YOLO模型进行图像识别或目标检测等任务。 5. 前端开发:在网页端,可以使用HTML、CSS和JavaScript等前端开发工具,创建用户界面和交互。通过用户界面,用户可以上传图像,并接收服务器端返回的识别结果。 6. 后端交互:前端页面通过Ajax等技术与后端服务器进行交互,将用户上传的图像发送给服务器,并接收服务器返回的识别结果。 7. 图像处理:服务器接收到图像后,使用OpenCV等库对图像进行预处理和调整大小等操作,以满足YOLO模型的输入要求。 8. 模型推理:服务器端使用导入的YOLO模型进行图像识别或目标检测。根据模型的输出,可以得出图像中存在的目标物体、位置和类别等信息。 9. 返回结果:服务器将识别结果(如目标位置、类别等)以JSON格式返回给前端页面。前端页面可以根据这些结果,显示或绘制边界框等视觉效果。 10. 调试与优化:进行测试和调试,确保网页端与服务器的通信正常,并保证YOLO模型在网页上的推理速度和准确性。 总之,将训练好的YOLO模型部署在网页上需要搭建服务器环境、导入模型文件、编写服务端代码、前端开发和后端交互等步骤,以实现图像识别和目标检测等功能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值