笔记(2)中利用tensorflow.js实现了一个经典的机器学习问题——CNN识别手写数字集MNIST。这篇笔记将利用web摄像头识别图像并判断上、下、左、右来玩吃豆人游戏。参考官方示例Transfer learning - Train a neural network to predict from webcam data,修改了部分代码。
1、首先引入已训练好的模型,mobilenet
async function loadMobilenet() {
const mobilenet = await tf.loadModel('./model.json');
const layer = mobilenet.getLayer('conv_pw_13_relu');
return tf.model({inputs: mobilenet.inputs, outputs: layer.output});
}
其中函数返回的tf.model中输入还是mobilenet的原始输入,输出为mobilenet的“conv_pw_13_relu”层。一般而言,因为越靠后所包含的训练信息越多,所以应选择已训练好的模型中越靠后的层。
2、定义摄像