前言
TensorFlow.js中加载预训练模型进行图片分类预测。
项目结构:
data:
|----mobileNet:
|----|----web_model:
|----|----|----group1-shard1of1.bin
|----|----|----model.json
mobileNet:
|----imagenet_classes.js
|----index.html
|----script.js
|--- untils.js
一、加载MobileNet模型
创建index.html程序入口文件,编写script标签跳转到script.js。
<script src="script.js"></script>
在script.js中编写程序主要代码。
加载预训练模型可以翻墙直接加载国外该模型的地址,如果没办法翻墙可以先把模型下载保存在本地,在Vscode中启动一个静态服务器,从静态服务器中加载。
启动静态服务器。我们的模型在data文件夹下面,所以在data中启动一个静态服务器,默认端口8080。
http-server data --cors
Available on:
http://192.168.4.167:8080
http://127.0.0.1:8080
Hit CTRL-C to stop the server
接着在script.js文件中加载mobilenet模型。
import * as tf from "@tensorflow/tfjs"
// 定义mobilenet模型地址
const MOBILENET_MODEL_PATH = "http://127.0.0.1:8080/mobilenet/web_model/model.json"
window.onload = async() => {
// 加载预训练模型
const model = tf.loadLayersModel(MOBILENET_MODEL_PATH);
};
二、编写前端界面输入带预测结果
在index.html中编写上传文件的功能。
<input type="file" onchange="predict(this.files[0])">
当上传文件就会触发predict方法进行预测。
新建一个文件utils.js,编写将file转成img的方法。
export function file2img(f){
return new Promise(resolve => {
const reader = new FileReader();
reader.readAsDataURL(f);
reader.onload = (e) =>{
const img = document.createElement('img');
img.src = e.target.result;
img.width = 224;
img.height = 224;
img.onload = () => resolve(img);
};
});
}
在script.js中将该方法引入。
import {file2img} from "./utils";
获取上传的图片,转换成img格式之后再转成tensor,并对其进行归一化操作。
window.predict = async(file) => {
// 将加载的图片文件转换成img格式
const img = await file2img(file);
const pred = tf.tidy(() => {
const input = tf.browser.fromPixels(img)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
return model.predict(input);
});
三、使用训练好的模型进行预测
// 预测
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`预测结果:${IMAGENET_CLASSES[index]}`);
}, 0);
四、完整代码
index.html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
script.js
import * as tf from "@tensorflow/tfjs"
import { IMAGENET_CLASSES } from './imagenet_classes';
import {file2img} from "./utils";
// 定义mobilenet模型地址
const MOBILENET_MODEL_PATH = "http://127.0.0.1:8080/mobilenet/web_model/model.json"
window.onload = async() => {
// 加载预训练模型
const model = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
window.predict = async(file) => {
// 将加载的图片文件转换成img格式
const img = await file2img(file);
document.body.appendChild(img)
const pred = tf.tidy(() => {
const input = tf.browser.fromPixels(img)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
return model.predict(input);
});
// 预测
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`预测结果:${IMAGENET_CLASSES[index]}`);
document.body.removeChild(img)
}, 0);
}
};
utils.js
export function file2img(f){
return new Promise(resolve => {
const reader = new FileReader();
reader.readAsDataURL(f);
reader.onload = (e) =>{
const img = document.createElement('img');
img.src = e.target.result;
img.width = 224;
img.height = 224;
img.onload = () => resolve(img);
};
});
}