1、小程序端环境准备
app.json
"plugins": {
"tfjsPlugin": {
"version": "0.2.0",
"provider": "wx6afed118d9e81df9"
}
}
package.json(注意:@tensorflow 各包版本需保持一致,tfjs-layers 4.x 依赖 tfjs-core 4.x,不能与 3.5.0 的 core/webgl 混用)
"dependencies": {
"@tensorflow-models/posenet": "^2.2.2",
"@tensorflow/tfjs-backend-webgl": "3.5.0",
"@tensorflow/tfjs-converter": "3.5.0",
"@tensorflow/tfjs-core": "3.5.0",
"@tensorflow/tfjs-layers": "3.5.0",
"fetch-wechat": "^0.0.3",
"lottie-miniprogram": "^1.0.12"
}
终端执行
Microsoft Windows [版本 10.0.19045.6093]
(c) Microsoft Corporation。保留所有权利。
E:\AAASelfProjectGit\myWxProject> npm i
在微信开发者工具中点击工具->构建npm
2、训练模型
python环境
python 3.8.20
protobuf 3.20.3
numpy 1.22.0
tensorflowjs 3.7.0
tensorflow 2.13.0
训练代码(使用 Keras 自带的 MNIST 手写数字数据集)
import tensorflow as tf
# Train a small fully-connected MNIST digit classifier and save it as HDF5,
# ready for tensorflowjs_converter.
SAVE_PATH = 'D:\\mnist.h5'

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Scale uint8 pixels from [0, 255] to floats in [0, 1]; the mini-program
# must feed inputs in this same range at inference time.
train_images = train_images / 255.0
test_images = test_images / 255.0

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

model.fit(train_images, train_labels, epochs=5)
model.evaluate(test_images, test_labels)
model.save(SAVE_PATH)
训练完成后,模型会保存为 D 盘根目录下的 mnist.h5。下面的代码调用 tensorflowjs_converter,将 .h5 文件转换成 TensorFlow.js 所需的 model.json 格式:
import os
import subprocess
# Convert the trained Keras .h5 model into TensorFlow.js format
# (model.json plus binary weight shard files) in output_dir.
h5_model_path = "D:\\mnist.h5"
output_dir = "D:\\"
os.makedirs(output_dir, exist_ok=True)

# Absolute path of the tensorflowjs_converter launcher inside the environment
# where tensorflowjs is installed (replace with your actual path; Windows example).
# The launcher is run directly rather than as `python <launcher>`: a bare
# "python" resolves to whichever interpreter is first on PATH, which may not be
# the environment that has tensorflowjs installed.
converter_path = "D:\\anaconda\\envs\\ckm\\Scripts\\tensorflowjs_converter.exe"
command = [
    converter_path,
    "--input_format=keras",
    h5_model_path,
    output_dir,
]
try:
    subprocess.run(command, check=True)
    print("转换成功!")
except subprocess.CalledProcessError as e:
    print(f"转换失败,错误代码: {e.returncode}")
except FileNotFoundError:
    # The launcher path above does not exist on this machine.
    print(f"未找到转换工具: {converter_path}")
转换成功后,输出目录中会得到两个文件:model.json(模型结构及权重清单)和权重文件(通常名为 group1-shard1of1.bin)。
将这两个文件上传到服务器的同一目录下(model.json 通过相对路径引用 .bin 权重文件),小程序中通过 model.json 的访问地址加载模型。注意需在小程序后台将该服务器域名配置为 request 合法域名。
3、小程序端代码预测
在js中引入并使用
var fetchWechat = require('fetch-wechat');
var tf = require('@tensorflow/tfjs-core');
var tfl = require('@tensorflow/tfjs-layers');
var webgl = require('@tensorflow/tfjs-backend-webgl');
var plugin = requirePlugin('tfjsPlugin');
// Side length, in pixels, of the MNIST images the model was trained on.
const MNIST_SIZE = 28;
// Run inference on one out of every FRAME_INTERVAL camera frames.
const FRAME_INTERVAL = 10;
// Location of the converted model.json (its .bin weight file sits next to it).
const MODEL_URL = 'https://你的服务器域名.com/model.json';

/**
 * Compute the largest centred square that fits inside a height x width frame.
 * Handles both portrait (height > width) and landscape (width > height) frames.
 *
 * @param {number} height - Frame height in pixels.
 * @param {number} width - Frame width in pixels.
 * @returns {{top: number, left: number, size: number}} Crop origin and side length.
 */
function centerSquareCrop(height, width) {
  const size = Math.min(height, width);
  return {
    top: Math.floor((height - size) / 2),
    left: Math.floor((width - size) / 2),
    size,
  };
}

/**
 * Camera page that classifies handwritten digits from live camera frames and
 * shows the predicted class in `data.result`.
 *
 * NOTE(review): fetchWechat, webgl and plugin are required at the top of this
 * file but not used in this page. tfjs in a mini-program normally needs
 * plugin.configPlugin({ fetchFunc, tf, webgl, canvas }) to run once (e.g. in
 * app.js) before a model can be fetched — confirm that happens elsewhere.
 */
Page({
  /**
   * Load the model, then pass every FRAME_INTERVAL-th camera frame to predict().
   */
  async onReady() {
    // Show the loading state while the model downloads, not after it finishes.
    this.setData({ result: 'Loading' });
    const camera = wx.createCameraContext(this);

    let net;
    try {
      net = await this.loadModel();
    } catch (error) {
      console.error('模型加载失败:', error);
      this.setData({ result: `Error: ${error.message}` });
      return;
    }

    let frameCount = 0;
    const listener = camera.onCameraFrame((frame) => {
      frameCount += 1;
      if (frameCount < FRAME_INTERVAL) {
        return;
      }
      frameCount = 0;
      // predict() catches its own errors, so its promise is intentionally not awaited.
      void this.predict(net, frame);
    });
    listener.start();
  },

  /**
   * Download the converted Keras model and log its layer summary.
   *
   * @param {string} [url=MODEL_URL] - Location of model.json.
   * @returns {Promise<tfl.LayersModel>} The loaded model.
   */
  async loadModel(url = MODEL_URL) {
    const net = await tfl.loadLayersModel(url);
    net.summary();
    return net;
  },

  /**
   * Classify one camera frame and write the predicted digit to `data.result`.
   *
   * The frame is cropped to its centred square, resized to 28x28, converted to
   * grayscale by averaging RGB, and scaled to [0, 1]. That is the same range
   * the model was trained on (the training script divides pixels by 255.0).
   * On failure, `data.result` is set to "Error: <message>".
   *
   * NOTE(review): MNIST digits are light strokes on a dark background; a dark
   * digit written on white paper probably needs inverting (1 - gray) — confirm
   * with real camera input.
   *
   * @param {tfl.LayersModel} net - Model returned by loadModel().
   * @param {{data: ArrayBuffer, width: number, height: number}} frame - RGBA camera frame.
   */
  async predict(net, frame) {
    try {
      // tidy() disposes every intermediate tensor, including when predict throws.
      const digit = tf.tidy(() => {
        const { width, height } = frame;
        const rgba = tf.tensor3d(new Uint8Array(frame.data), [height, width, 4]);
        const { top, left, size } = centerSquareCrop(height, width);
        // Crop the centred square and drop the alpha channel.
        const rgb = tf.slice(rgba, [top, left, 0], [size, size, 3]);
        const resized = tf.image.resizeBilinear(rgb, [MNIST_SIZE, MNIST_SIZE]);
        // Grayscale [28, 28], scaled to [0, 1] to match training.
        const gray = tf.div(tf.mean(resized, 2), 255);
        // Add the batch dimension: [1, 28, 28].
        const input = tf.expandDims(gray, 0);
        const probabilities = net.predict(input);
        // topk(..., 1) is used in place of argMax (kept from the original code).
        const { indices } = tf.topk(probabilities, 1);
        return indices.dataSync()[0];
      });
      this.setData({ result: digit });
    } catch (error) {
      console.error('预测错误:', error);
      this.setData({ result: `Error: ${error.message}` });
    }
  },
});
在wxml中展示{{result}}即可看到预测结果
<view class="landscape-container">
  <!-- Camera layer (landscape layout). No tap binding: this page's Page() defines no cameraClick method. -->
  <camera device-position="back" resolution="high" frame-size="large" style="width: 100%; height: 100vh;z-index:10;" id="myCamera">
  </camera>
  <!-- Prediction overlay, stacked above the camera (z-index 99 > 10). -->
  <view style="position: absolute;bottom: 100px;z-index: 99;left: 50%;transform: translateX(-50%);font-size: 20px;font-weight: 800;color: white;">
    预测结果:{{result}}
  </view>
</view>