基于Tensorflow.js的花卉识别编程实践

使用TensorFlow.js进行编程有许多优点,特别适合开发机器学习和深度学习的应用。TensorFlow.js可以直接在浏览器中运行,无需服务器或特殊环境配置。这使得开发者可以轻松地创建和部署基于Web的机器学习应用。TensorFlow.js提供了许多预训练模型,开发者可以直接使用这些模型进行各种任务,如图像分类、物体检测、自然语言处理等,减少了从头开始训练模型的时间和资源。
使用Tensorflow.js,所有的计算都在本地进行,数据不会被发送到服务器,有助于保护用户的隐私。由于数据不离开用户的设备,减少了数据泄露的风险。
下面以基于TensorFlow.js的花卉识别为例,讲述编程涉及的关键问题。
主要功能是:
(1)用户选择图片。
(2)选择花朵的2种方式:

  • 双击鼠标:已鼠标位置为中心,固定大小的矩形框确定花朵区域。
  • 拖动鼠标框选花朵区域。

(3)裁剪图片,获取花朵图片,识别花朵类别,显示识别结果。

深度学习的训练模型需要保存为SavedModel的模型,再转换为Tensorflow.js格式模型(1个json文件+几个bin文件)。
下面对关键的代码进行说明。

index.html中需要加载Tensorflow.js库

<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>

花卉识别的主要功能均在ai_flower.js实现,下面介绍其中的关键函数。

加载模型和标签
  • loadModel():异步加载TensorFlow模型。加载成功后更新页面提示。
async function loadModel() {
  try {
    // 加载Savedmodel转换模型
    model = await tf.loadGraphModel(modelUrl);
    
    document.getElementById('result').innerText = "Model loaded successfully.";
    
    // console.log("Model input:", model.input)
    // console.log(model.summary())

  } catch (error) {
    console.error("Error loading model:", error);
    document.getElementById('predict').disabled = true;
    document.getElementById('result').innerText = "Error loading model.";
  }
}
  • loadLabelMap(label_map_path, label_encn_path):异步加载类别标签和英文-中文对照表。
async function loadLabelMap(label_map_path, label_encn_path){
  try {
    // 获取标签类别名称
    const responseMap = await fetch(label_map_path);
    labelMap = await responseMap.json();

    const responseEncn = await fetch(label_encn_path);
    leabelEnCn = await responseEncn.json();
    
  } catch (error) {
    console.error("加载类别标签失败:", error);
    document.getElementById('result').innerText = "加载类别标签失败。";
  }
}
处理图片上传
  • handleImageUpload(event):处理图片选择事件,清空上次识别结果,加载新的图片并显示在页面上。同时创建临时图像对象获取原始图像尺寸,并更新覆盖画布的尺寸和位置。
function handleImageUpload(event) {
  // 清空上次识别结果
  document.getElementById('result').innerText = "";
  document.getElementById('promptText').innerText = "";

  // 清除之前的裁剪图片
  const croppedImgElement = document.getElementById('croppedImage');
  croppedImgElement.src = '';
  croppedImgElement.style.display = 'none';
  // 清除虚线框
  overlay.style.display = 'none';

  // 禁用预测按钮
  document.getElementById('predict').disabled = true; 

  // 未确定矩形框
  isSelectRect = false;
  
  // 获取 imgElement
  const imgElement = document.getElementById('image');
  imgElement.style.display = 'none';

  // 清除之前的虚线框
  const overlayCanvas = document.getElementById('overlayCanvas');
  overlayCanvas.style.display = 'none';

  // 加载图片
  const file = event.target.files[0];
  const reader = new FileReader();
  
  reader.onload = function(e) {
    const imgElement = document.getElementById('image');
    imgElement.src = e.target.result;
    imgElement.style.display = 'block';
    
    // 创建一个临时的 Image 对象来获取原始尺寸
    const tempImg = new Image();
    tempImg.onload = function() {
      // 保存原始图像数据和尺寸
      const canvas = document.createElement('canvas');
      canvas.width = tempImg.naturalWidth;
      canvas.height = tempImg.naturalHeight;
      const ctx = canvas.getContext('2d');
      ctx.drawImage(tempImg, 0, 0);

      imgElement.dataset.originalImage = canvas.toDataURL();
      imgElement.dataset.naturalWidth = tempImg.naturalWidth;
      imgElement.dataset.naturalHeight = tempImg.naturalHeight;

      // 更新 overlayCanvas 尺寸和位置
      const overlayCanvas = document.getElementById('overlayCanvas');
      overlayCanvas.width = tempImg.naturalWidth;
      overlayCanvas.height = tempImg.naturalHeight;
      overlayCanvas.style.display = 'block';

      // 确保 canvas 与 imgElement 对齐
      const imgRect = imgElement.getBoundingClientRect();
      overlayCanvas.style.position = 'absolute';
      overlayCanvas.style.left = `${imgRect.left}px`;
      overlayCanvas.style.top = `${imgRect.top}px`;
      
      // 添加提示文本
      document.getElementById('promptText').innerText = '请拖动鼠标,框选花朵。';
    };

    tempImg.src = e.target.result;
  };
  
  reader.readAsDataURL(file);
}
裁剪并显示图像
  • cropImage(cropStartX, cropStartY, cropEndX, cropEndY):根据用户拖动的矩形框或双击鼠标确定的矩形框,裁剪图像,并显示裁剪后的图像(方便调试)。裁剪区域相对于原始图像进行计算,以确保裁剪的准确性。
function cropImage(cropStartX, cropStartY, cropEndX, cropEndY) {
  const imgElement = document.getElementById('image');
  const originalImageData = imgElement.dataset.originalImage;
  const naturalWidth = parseInt(imgElement.dataset.naturalWidth, 10);
  const naturalHeight = parseInt(imgElement.dataset.naturalHeight, 10);

  // Get the image's bounding rectangle
  const imgRect = imgElement.getBoundingClientRect();

  // Calculate the scaling factors
  const scaleX = naturalWidth / imgElement.width;
  const scaleY = naturalHeight / imgElement.height;

  // Convert displayed coordinates to original image coordinates
  const sx = cropStartX * scaleX;
  const sy = cropStartY * scaleY;
  const ex = cropEndX * scaleX;
  const ey = cropEndY * scaleY;

  const width = ex - sx;
  const height = ey - sy;

  // Ensure the crop area is within image bounds
  const adjustedStartX = Math.max(0, Math.min(sx, naturalWidth - width));
  const adjustedStartY = Math.max(0, Math.min(sy, naturalHeight - height));

  const canvas = document.createElement('canvas');
  const ctx = canvas.getContext('2d');

  // Set canvas dimensions
  canvas.width = cropEndX - cropStartX;
  canvas.height = cropEndY - cropStartY;

  // Create a temporary image object to load the original image data
  const tempImg = new Image();
  tempImg.onload = function() {
    // Draw the cropped area on the canvas
    ctx.drawImage(tempImg, adjustedStartX, adjustedStartY, width, height, 0, 0, canvas.width, canvas.height);

    const croppedImgElement = document.getElementById('croppedImage');
    croppedImgElement.src = canvas.toDataURL();
    croppedImgElement.style.display = 'block';
  };

  tempImg.src = originalImageData;

  document.getElementById('predict').disabled = false; // 启用预测按钮
  document.getElementById('promptText').innerText = ''

}
预测图像
  • predictImage():使用加载的模型对裁剪后的图像进行预测。获取图像元素并将其转换为TensorFlow张量。调整图像大小,归一化并添加批量维度后进行预测。解析预测结果并显示前k个预测结果。
async function predictImage() {
  if (!model) {
    document.getElementById('result').innerText = "模型尚未加载。";
    return;
  }

  if (!isSelectRect) {
    document.getElementById('result').innerText = "请拖动鼠标,框选花朵。";
    return;
  }

  // 获取图像元素
  const imgElement = document.getElementById('croppedImage');
  // 从图像元素创建张量
  const tensorImg = tf.browser.fromPixels(imgElement).toFloat();

  // 调整为模型需要的输入大小
  const resizedImg = tf.image.resizeBilinear(tensorImg, [224, 224]); 
  // 归一化图像
  const normalizedImg = resizedImg.div(255.0);
  // 添加批量维度
  const batchedImg = normalizedImg.expandDims(0);

  // 进行预测
  let predictions;
  try {
    predictions = await model.execute(batchedImg);
  } catch (error) {
    console.error("模型预测失败:", error);
    document.getElementById('result').innerText = "模型预测失败。";
    return;
  }

  // 检查 predictions 是否有效
  if (!predictions || Array.isArray(predictions) && predictions.length === 0) {
    console.error("模型预测返回了无效的输出。");
    document.getElementById('result').innerText = "模型预测返回了无效的输出。";
    return;
  }

  // 获取第一个 Tensor 作为输出
  // 实际不是数组
  const outputTensor = Array.isArray(predictions) ? predictions[0] : predictions;

  if (!outputTensor) {
    console.error("输出 Tensor 未定义");
    document.getElementById('result').innerText = "输出 Tensor 未定义。";
    return;
  }

  // 将 Tensor 转换为数组
  let probabilities;
  try {
    probabilities = await outputTensor.data(); // 使用 .data() 而不是 .array()
    // console.log("probabilities:")
    // console.log(probabilities)
  } catch (error) {
    console.error("Tensor 转换为数组失败:", error);
    document.getElementById('result').innerText = "Tensor 转换为数组失败。";
    return;
  }

  // 获取 top-k 预测
  const topKIndices = Array.from(probabilities)
                             .map((prob, index) => ({prob, index}))
                             .sort((a, b) => b.prob - a.prob)
                             .slice(0, topK)
                             .map(item => item.index);
 
  const topKProbabilities = topKIndices.map(index => probabilities[index]);

  // console.log(topKIndices)
  // console.log(topKProbabilities)

  // 显示 top-k 预测结果, label编号从1开始(index+1),不是从0开始
  let resultText = "";
  topKIndices.forEach((index, i) => {
     const className = labelMap[index + 1] || "Unknown"; // 从字典中获取类别名称
     const cnName = leabelEnCn[className];
     const probability = topKProbabilities[i];
     resultText += `${className}-${cnName}: ${probability.toFixed(4)}\n`; // 保留四位小数
   });

  document.getElementById('result').innerText = resultText;
}

下载完整源代码

  • 5
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值