使用TensorFlow.js进行编程有许多优点,特别适合开发机器学习和深度学习的应用。TensorFlow.js可以直接在浏览器中运行,无需服务器或特殊环境配置。这使得开发者可以轻松地创建和部署基于Web的机器学习应用。TensorFlow.js提供了许多预训练模型,开发者可以直接使用这些模型进行各种任务,如图像分类、物体检测、自然语言处理等,减少了从头开始训练模型的时间和资源。
使用Tensorflow.js,所有的计算都在本地进行,数据不会被发送到服务器,有助于保护用户的隐私。由于数据不离开用户的设备,减少了数据泄露的风险。
下面以基于TensorFlow.js的花卉识别为例,讲述编程涉及的关键问题。
主要功能是:
(1)用户选择图片。
(2)选择花朵的2种方式:
- 双击鼠标:已鼠标位置为中心,固定大小的矩形框确定花朵区域。
- 拖动鼠标框选花朵区域。
(3)裁剪图片,获取花朵图片,识别花朵类别,显示识别结果。
深度学习的训练模型需要保存为SavedModel的模型,再转换为Tensorflow.js格式模型(1个json文件+几个bin文件)。
下面对关键的代码进行说明。
index.html中需要加载Tensorflow.js库
<!-- Import TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
花卉识别的主要功能均在ai_flower.js实现,下面介绍其中的关键函数。
加载模型和标签
loadModel()
:异步加载TensorFlow模型。加载成功后更新页面提示。
async function loadModel() {
try {
// 加载Savedmodel转换模型
model = await tf.loadGraphModel(modelUrl);
document.getElementById('result').innerText = "Model loaded successfully.";
// console.log("Model input:", model.input)
// console.log(model.summary())
} catch (error) {
console.error("Error loading model:", error);
document.getElementById('predict').disabled = true;
document.getElementById('result').innerText = "Error loading model.";
}
}
loadLabelMap(label_map_path, label_encn_path)
:异步加载类别标签和英文-中文对照表。
async function loadLabelMap(label_map_path, label_encn_path){
try {
// 获取标签类别名称
const responseMap = await fetch(label_map_path);
labelMap = await responseMap.json();
const responseEncn = await fetch(label_encn_path);
leabelEnCn = await responseEncn.json();
} catch (error) {
console.error("加载类别标签失败:", error);
document.getElementById('result').innerText = "加载类别标签失败。";
}
}
处理图片上传
handleImageUpload(event)
:处理图片选择事件,清空上次识别结果,加载新的图片并显示在页面上。同时创建临时图像对象获取原始图像尺寸,并更新覆盖画布的尺寸和位置。
function handleImageUpload(event) {
// 清空上次识别结果
document.getElementById('result').innerText = "";
document.getElementById('promptText').innerText = "";
// 清除之前的裁剪图片
const croppedImgElement = document.getElementById('croppedImage');
croppedImgElement.src = '';
croppedImgElement.style.display = 'none';
// 清除虚线框
overlay.style.display = 'none';
// 禁用预测按钮
document.getElementById('predict').disabled = true;
// 未确定矩形框
isSelectRect = false;
// 获取 imgElement
const imgElement = document.getElementById('image');
imgElement.style.display = 'none';
// 清除之前的虚线框
const overlayCanvas = document.getElementById('overlayCanvas');
overlayCanvas.style.display = 'none';
// 加载图片
const file = event.target.files[0];
const reader = new FileReader();
reader.onload = function(e) {
const imgElement = document.getElementById('image');
imgElement.src = e.target.result;
imgElement.style.display = 'block';
// 创建一个临时的 Image 对象来获取原始尺寸
const tempImg = new Image();
tempImg.onload = function() {
// 保存原始图像数据和尺寸
const canvas = document.createElement('canvas');
canvas.width = tempImg.naturalWidth;
canvas.height = tempImg.naturalHeight;
const ctx = canvas.getContext('2d');
ctx.drawImage(tempImg, 0, 0);
imgElement.dataset.originalImage = canvas.toDataURL();
imgElement.dataset.naturalWidth = tempImg.naturalWidth;
imgElement.dataset.naturalHeight = tempImg.naturalHeight;
// 更新 overlayCanvas 尺寸和位置
const overlayCanvas = document.getElementById('overlayCanvas');
overlayCanvas.width = tempImg.naturalWidth;
overlayCanvas.height = tempImg.naturalHeight;
overlayCanvas.style.display = 'block';
// 确保 canvas 与 imgElement 对齐
const imgRect = imgElement.getBoundingClientRect();
overlayCanvas.style.position = 'absolute';
overlayCanvas.style.left = `${imgRect.left}px`;
overlayCanvas.style.top = `${imgRect.top}px`;
// 添加提示文本
document.getElementById('promptText').innerText = '请拖动鼠标,框选花朵。';
};
tempImg.src = e.target.result;
};
reader.readAsDataURL(file);
}
裁剪并显示图像
cropImage(cropStartX, cropStartY, cropEndX, cropEndY)
:根据用户拖动的矩形框或双击鼠标确定的矩形框,裁剪图像,并显示裁剪后的图像(方便调试)。裁剪区域相对于原始图像进行计算,以确保裁剪的准确性。
function cropImage(cropStartX, cropStartY, cropEndX, cropEndY) {
const imgElement = document.getElementById('image');
const originalImageData = imgElement.dataset.originalImage;
const naturalWidth = parseInt(imgElement.dataset.naturalWidth, 10);
const naturalHeight = parseInt(imgElement.dataset.naturalHeight, 10);
// Get the image's bounding rectangle
const imgRect = imgElement.getBoundingClientRect();
// Calculate the scaling factors
const scaleX = naturalWidth / imgElement.width;
const scaleY = naturalHeight / imgElement.height;
// Convert displayed coordinates to original image coordinates
const sx = cropStartX * scaleX;
const sy = cropStartY * scaleY;
const ex = cropEndX * scaleX;
const ey = cropEndY * scaleY;
const width = ex - sx;
const height = ey - sy;
// Ensure the crop area is within image bounds
const adjustedStartX = Math.max(0, Math.min(sx, naturalWidth - width));
const adjustedStartY = Math.max(0, Math.min(sy, naturalHeight - height));
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
// Set canvas dimensions
canvas.width = cropEndX - cropStartX;
canvas.height = cropEndY - cropStartY;
// Create a temporary image object to load the original image data
const tempImg = new Image();
tempImg.onload = function() {
// Draw the cropped area on the canvas
ctx.drawImage(tempImg, adjustedStartX, adjustedStartY, width, height, 0, 0, canvas.width, canvas.height);
const croppedImgElement = document.getElementById('croppedImage');
croppedImgElement.src = canvas.toDataURL();
croppedImgElement.style.display = 'block';
};
tempImg.src = originalImageData;
document.getElementById('predict').disabled = false; // 启用预测按钮
document.getElementById('promptText').innerText = ''
}
预测图像
predictImage()
:使用加载的模型对裁剪后的图像进行预测。获取图像元素并将其转换为TensorFlow张量。调整图像大小,归一化并添加批量维度后进行预测。解析预测结果并显示前k个预测结果。
async function predictImage() {
if (!model) {
document.getElementById('result').innerText = "模型尚未加载。";
return;
}
if (!isSelectRect) {
document.getElementById('result').innerText = "请拖动鼠标,框选花朵。";
return;
}
// 获取图像元素
const imgElement = document.getElementById('croppedImage');
// 从图像元素创建张量
const tensorImg = tf.browser.fromPixels(imgElement).toFloat();
// 调整为模型需要的输入大小
const resizedImg = tf.image.resizeBilinear(tensorImg, [224, 224]);
// 归一化图像
const normalizedImg = resizedImg.div(255.0);
// 添加批量维度
const batchedImg = normalizedImg.expandDims(0);
// 进行预测
let predictions;
try {
predictions = await model.execute(batchedImg);
} catch (error) {
console.error("模型预测失败:", error);
document.getElementById('result').innerText = "模型预测失败。";
return;
}
// 检查 predictions 是否有效
if (!predictions || Array.isArray(predictions) && predictions.length === 0) {
console.error("模型预测返回了无效的输出。");
document.getElementById('result').innerText = "模型预测返回了无效的输出。";
return;
}
// 获取第一个 Tensor 作为输出
// 实际不是数组
const outputTensor = Array.isArray(predictions) ? predictions[0] : predictions;
if (!outputTensor) {
console.error("输出 Tensor 未定义");
document.getElementById('result').innerText = "输出 Tensor 未定义。";
return;
}
// 将 Tensor 转换为数组
let probabilities;
try {
probabilities = await outputTensor.data(); // 使用 .data() 而不是 .array()
// console.log("probabilities:")
// console.log(probabilities)
} catch (error) {
console.error("Tensor 转换为数组失败:", error);
document.getElementById('result').innerText = "Tensor 转换为数组失败。";
return;
}
// 获取 top-k 预测
const topKIndices = Array.from(probabilities)
.map((prob, index) => ({prob, index}))
.sort((a, b) => b.prob - a.prob)
.slice(0, topK)
.map(item => item.index);
const topKProbabilities = topKIndices.map(index => probabilities[index]);
// console.log(topKIndices)
// console.log(topKProbabilities)
// 显示 top-k 预测结果, label编号从1开始(index+1),不是从0开始
let resultText = "";
topKIndices.forEach((index, i) => {
const className = labelMap[index + 1] || "Unknown"; // 从字典中获取类别名称
const cnName = leabelEnCn[className];
const probability = topKProbabilities[i];
resultText += `${className}-${cnName}: ${probability.toFixed(4)}\n`; // 保留四位小数
});
document.getElementById('result').innerText = resultText;
}