TensorFlow.js实现商标识别

在VsCode中利用TensorFlow.js结合迁移学习实现商标识别。

一、加载商标数据并可视化

数据保存在data文件夹下面,需要先在data文件夹下创建一个静态服务器,用于加载图片。

http-server data --cors
Available on:
  http://192.168.4.167:8080
  http://127.0.0.1:8080
Hit CTRL-C to stop the server

 编写获取图片的脚本文件。

const IMAGE_SIZE = 224;

const loadImg = (src) => {
    return new Promise(resolve => {
        const img = new Image();
        img.crossOrigin = "anonymous";
        img.src = src;
        img.width = IMAGE_SIZE;
        img.height = IMAGE_SIZE;
        img.onload = () => resolve(img);
    });
};
export const getInputs = async () => {
    const loadImgs = [];
    const labels = [];
    for (let i = 0; i < 30; i += 1) {
        ['android', 'apple', 'windows'].forEach(label => {
            const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;
            const img = loadImg(src);
            loadImgs.push(img);
            labels.push([
                label === 'android' ? 1 : 0,
                label === 'apple' ? 1 : 0,
                label === 'windows' ? 1 : 0,
            ]);
        });
    }
    const inputs = await Promise.all(loadImgs);
    return {
        inputs,
        labels,
    };
}

创建index.html文件,作为程序的入口文件,在index.html中利用script标签跳转到script.js文件,在script.js中编写主要代码。

加载图片数据。

import {getInputs} from "./data";

window.onload = async() =>{
    const {inputs, labels} = await getInputs();
    console.log(inputs, labels)

};

利用TensorFlow.js中的tfvis进行可视化。

    import * as tfvis from "@tensorflow/tfjs-vis"

    // 可视化图片
    const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });
    inputs.forEach(img => {
        surface.drawArea.appendChild(img);
    });

每行显示两个,旁边的滚动体可以拉动查看更多图片。

二、定义模型结构

加载MobileNet模型并截断所有的卷积池化操作,生成截断模型。并定义新的全连接层。

    import * as tf from "@tensorflow/tfjs"

    // mobilenet模型存放位置
    const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
    
    // 加载MobileNet模型
    const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
    // 查看模型结构
    mobilenet.summary();
    // 截断mobilenet卷积操作
    const layer = mobilenet.getLayer('conv_pw_13_relu');
    const truncatedMobilenet = tf.model({
        inputs: mobilenet.inputs,
        outputs: layer.output
    });

    // 定义全连接层
    const model = tf.sequential();
    model.add(tf.layers.flatten({
        inputShape: layer.outputShape.slice(1)
    }));
    model.add(tf.layers.dense({
        units: 10,
        activation: 'relu'
    }));
    // 定义输出层
    model.add(tf.layers.dense({
        units: NUM_CLASSES,
        activation: 'softmax'
    }));
    // 配置损失函数和优化器
    model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });

三、迁移学习下的模型训练

首先先定义一个工具类utils.js,用于处理输入到截断模型(mobilenet)中的数据。

import * as tf from '@tensorflow/tfjs';
// img格式转成tensor
export function img2x(imgEl){
    return tf.tidy(() => {
        const input = tf.browser.fromPixels(imgEl)
            .toFloat()
            .sub(255 / 2)
            .div(255 / 2)
            .reshape([1, 224, 224, 3]);
        return input;
    });
}
// 图片文件转成img格式
export function file2img(f) {
    return new Promise(resolve => {
        const reader = new FileReader();
        reader.readAsDataURL(f);
        reader.onload = (e) => {
            const img = document.createElement('img');
            img.src = e.target.result;
            img.width = 224;
            img.height = 224;
            img.onload = () => resolve(img);
        };
    });
}

训练数据经过截断模型输出,转为可以用于自定义的全连接层的输入数据。

    // 先经过截断模型
    const { xs, ys } = tf.tidy(() => {
        const xs = tf.concat(inputs.map(imgEl =>truncatedMobilenet.predict(img2x(imgEl))));
        const ys = tf.tensor(labels);
        return { xs, ys };
    });
    // 截断模型的输出当成自定义模型的输入
    await model.fit(xs, ys, {
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: '训练效果' },
            ['loss'],
            { callbacks: ['onEpochEnd'] }
        )
    });

可以看出训练损失值降得非常低,因为采用迁移模型,卷积层的参数是使用别人训练好的,这部分参数的训练结果是非常优秀的。

四、预测 

编写前端页面用于上传带预测图片,就是编写一个上传按钮。

<script src="script.js"></script>

<input type="file" onchange="predict(this.files[0])">

将预测图片先经过mobileNet预测,吐出来的结果再经过自定义模型预测。

    window.predict = async (file) => {
        const img = await file2img(file);
        document.body.appendChild(img);
        const pred = tf.tidy(() => {
            const x = img2x(img);
            const input = truncatedMobilenet.predict(x);
            return model.predict(input);
        });

        const index = pred.argMax(1).dataSync()[0];
        setTimeout(() => {
            alert(`预测结果:${BRAND_CLASSES[index]}`);
        }, 0);
    };

 

五、完整代码

index.html

<script src="script.js"></script>

<input type="file" onchange="predict(this.files[0])">

script.js 

import * as tf from "@tensorflow/tfjs"
import * as tfvis from "@tensorflow/tfjs-vis"
import {getInputs} from "./data";
import {img2x, file2img} from "./utils"

// mobilenet模型存放位置
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';

const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];


window.onload = async() =>{
    // 加载图片
    const {inputs, labels} = await getInputs();
    // 可视化图片
    const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });
    inputs.forEach(img => {
        surface.drawArea.appendChild(img);
    });

    // 加载MobileNet模型
    const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
    // 查看模型结构
    mobilenet.summary();
    // 截断mobilenet卷积操作
    const layer = mobilenet.getLayer('conv_pw_13_relu');
    const truncatedMobilenet = tf.model({
        inputs: mobilenet.inputs,
        outputs: layer.output
    });

    // 定义全连接层
    const model = tf.sequential();
    model.add(tf.layers.flatten({
        inputShape: layer.outputShape.slice(1)
    }));
    model.add(tf.layers.dense({
        units: 10,
        activation: 'relu'
    }));
    // 定义输出层
    model.add(tf.layers.dense({
        units: NUM_CLASSES,
        activation: 'softmax'
    }));
    // 配置损失函数和优化器
    model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });


    // 先经过截断模型
    const { xs, ys } = tf.tidy(() => {
        const xs = tf.concat(inputs.map(imgEl => truncatedMobilenet.predict(img2x(imgEl))));
        const ys = tf.tensor(labels);
        return { xs, ys };
    });
    // 截断模型的输出当成自定义模型的输入
    await model.fit(xs, ys, {
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: '训练效果' },
            ['loss'],
            { callbacks: ['onEpochEnd'] }
        )
    });

    // 预测
    window.predict = async (file) => {
        const img = await file2img(file);
        document.body.appendChild(img);
        const pred = tf.tidy(() => {
            const x = img2x(img);
            const input = truncatedMobilenet.predict(x);
            return model.predict(input);
        });

        const index = pred.argMax(1).dataSync()[0];
        setTimeout(() => {
            alert(`预测结果:${BRAND_CLASSES[index]}`);
        }, 0);
    };


};

data.js 

const IMAGE_SIZE = 224;

const loadImg = (src) => {
    return new Promise(resolve => {
        const img = new Image();
        img.crossOrigin = "anonymous";
        img.src = src;
        img.width = IMAGE_SIZE;
        img.height = IMAGE_SIZE;
        img.onload = () => resolve(img);
    });
};
export const getInputs = async () => {
    const loadImgs = [];
    const labels = [];
    for (let i = 0; i < 30; i += 1) {
        ['android', 'apple', 'windows'].forEach(label => {
            const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;
            const img = loadImg(src);
            loadImgs.push(img);
            labels.push([
                label === 'android' ? 1 : 0,
                label === 'apple' ? 1 : 0,
                label === 'windows' ? 1 : 0,
            ]);
        });
    }
    const inputs = await Promise.all(loadImgs);
    return {
        inputs,
        labels,
    };
}

utils.js 

import * as tf from '@tensorflow/tfjs';

export function img2x(imgEl){
    return tf.tidy(() => {
        const input = tf.browser.fromPixels(imgEl)
            .toFloat()
            .sub(255 / 2)
            .div(255 / 2)
            .reshape([1, 224, 224, 3]);
        return input;
    });
}

export function file2img(f) {
    return new Promise(resolve => {
        const reader = new FileReader();
        reader.readAsDataURL(f);
        reader.onload = (e) => {
            const img = document.createElement('img');
            img.src = e.target.result;
            img.width = 224;
            img.height = 224;
            img.onload = () => resolve(img);
        };
    });
}

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

.Thinking.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值