原文链接: tfjs vue 风格迁移 展示
上一篇: mobilenet 和 vgg 19 的参数量计算
下一篇: flowers 数据集分类 vgg19 微调网络 保存为pb格式
对android支持有缺陷,主要是安卓的浏览器webgl性能不够,ios支持不错,至少ipad可以跑
下载模型文件
https://github.com/zaidalyafeai/Fast-Style-Transfer-Keras-TF.js
该模型只能接受256*256的图片,并且返回256*256的图片,风格信息在网络中,每次添加风格需要重新生成网络
style3 中文件有缺少的所以只使用了四个
项目结构,将常用函数进行封装放入utils文件中
将模型文件和风格图片放入static目录中
关键两个代码。
tf.fromPixels 传入一个img文档对象,内容可以是base64或者blob,超链接都可以,返回一个tensor,大小是img对象的宽高
tf.toPixels 可以将一个tensor绘制到canvas中,大小可以任意设置,会进行相应的拉伸
async function predict(model, img) {
let img_mat = await preprocess(img)
let ret = await model.predict(img_mat)
return await deprocess(ret)
}
async show(style) {
let img = document.getElementById('img')
let canvas = document.getElementById("mix");
let tensor = tf.fromPixels(img).toFloat()
let ret = await predict(await this.models[style], tensor)
tf.toPixels(ret, canvas)
}
工具函数,只将predict函数暴露出去,传入模型和图片张量返回处理后结果
function deprocess(x) {
return tf.tidy(() => {
const offset = tf.scalar(127.5);
// Normalize the image
const denormalized = x.mul(offset).add(offset).toInt();
const reduced = denormalized.squeeze()
return reduced
})
}
function preprocess(tensor) {
return tf.tidy(() => {
const offset = tf.scalar(127.5);
// Normalize the image
const normalized = tensor.sub(offset).div(offset);
//We add a dimension to get a batch shape
const batched = normalized.expandDims(0)
return batched
})
}
async function predict(model, img) {
let img_mat = await preprocess(img)
let ret = await model.predict(img_mat)
return await deprocess(ret)
}
export {
predict,
}
App.vue
先将模型加载,然后根据点击的风格图片调用不同的风格网络,传递给predict函数计算结果并绘制在canvas中
<template>
<div id="app">
<input type="file" id="upload" @change="change">
<img src="/static/default.jpg" alt="" class="img" id="img">
<canvas id="mix" class="mix"></canvas>
<div>
<img @click="show('style1')" src="/static/style_image/style1.jpg" class="style_image">
<img @click="show('style2')" src="/static/style_image/style2.jpg" class="style_image">
<!--<img @click="show('style3')" src="/static/style_image/style3.jpg" class="style_image">-->
<img @click="show('style4')" src="/static/style_image/style4.jpg" class="style_image">
<img @click="show('style5')" src="/static/style_image/style5.jpg" class="style_image">
</div>
</div>
</template>
<script>
import * as tf from '@tensorflow/tfjs';
import {predict} from './utils/index'
export default {
name: "draw",
data() {
return {
prediction: '?',
models: {}
}
},
methods: {
async change(e) {
let img = document.getElementById('img')
let file = document.getElementById('upload').files[0]
console.log(file)
let url = window.URL.createObjectURL(file);
img.src = url
},
async show(style) {
let img = document.getElementById('img')
let canvas = document.getElementById("mix");
let tensor = tf.fromPixels(img).toFloat()
let ret = await predict(await this.models[style], tensor)
tf.toPixels(ret, canvas)
}
},
async mounted() {
this.models = {
style1: tf.loadModel('./static/style_model/style1/model.json'),
style2: tf.loadModel('./static/style_model/style2/model.json'),
// style3:tf.loadModel('./static/style_model/style3/model.json'),
style4: tf.loadModel('./static/style_model/style4/model.json'),
style5: tf.loadModel('./static/style_model/style5/model.json'),
}
}
}
</script>
<style>
#app {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
margin: 5px;
}
.img {
width: 256px;
height: 256px;
border: 1px solid gray;
margin: 5px;
}
.mix {
width: 256px;
height: 256px;
border: 1px solid gray;
margin: 5px;
}
.style_image {
width: 128px;
height: 128px;
margin: 5px;
}
</style>