【webAI】Tensorflow.js加载预训练的model

环境准备

  • win10
  • python3.6
  • pip install tensorflow
  • pip install tensorflowjs

训练并保存tensorflow模型为saved_model

# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

# 下载mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# 初始化session
sess = tf.InteractiveSession()

def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# 神经网络参数
n_input = 784
n_node = 256
n_out = 10

x = tf.placeholder(tf.float32, [None, n_input], name="x")
y_ = tf.placeholder(tf.float32, [None, n_out])

# 第一层
W = weight_variable([n_input, n_node])
b = bias_variable([n_node])
layer_h = tf.nn.relu(tf.matmul(x, W) + b)

# 第二层
W_out = bias_variable([n_node, n_out])
b_out = bias_variable([n_out])
y = tf.nn.relu(tf.matmul(layer_h, W_out) + b_out)

softmax = tf.nn.softmax(y, name="softmax")

# LOSS损失函数
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y))

correct_prediction = tf.equal(tf.argmax(softmax, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# 训练模型
train_step = tf.train.AdamOptimizer().minimize(cross_entropy)

tf.global_variables_initializer().run()
for i in range(2000):
  batch = mnist.train.next_batch(50)
  if i % 200 == 0:
    train_accuracy = accuracy.eval(feed_dict={
        x: batch[0], y_: batch[1]})
    print('step %d, training accuracy %g' % (i, train_accuracy))
  train_step.run(feed_dict={x: batch[0], y_: batch[1]})

print('test accuracy %g' % accuracy.eval(feed_dict={
    x: mnist.test.images, y_: mnist.test.labels}))

# 保存模型为saved_model
tf.saved_model.simple_save(sess, "./saved_model",
                           inputs={"x": x, }, outputs={"softmax": softmax, })

转换tensorflow的模型

tensorflowjs_converter --input_format=tf_saved_model \
  --output_node_names="softmax" \
  --saved_model_tags=serve ./saved_model \
  ./web_model
  • 转换后的模型文件
  • tensorflowjs_model.pb 为 tensorflow.js能识别的模型
  • weights_manifest.json 为 tensorflow.js能识别的模型参数文件


这里写图片描述


Tensorflow.js加载转换后的模型

import * as tf from '@tensorflow/tfjs'
import {loadFrozenModel} from '@tensorflow/tfjs-converter'

const MODEL_URL = 'tensorflowjs_model.pb'
const WEIGHTS_URL = 'weights_manifest.json'

async function predict() {
    try {
      const model = await loadFrozenModel(MODEL_URL, WEIGHTS_URL)
      var xs = tf.tensor2d([pixels])
      var output = model.execute({x: xs})
      console.log(output.dataSync())
      return output
    } catch (e) {
      console.log(e)
    }
 }
  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
加载tflite模型进行图片识别可以通过以下步骤实现: 1. 准备模型文件 首先需要准备好tflite模型文件。可以从TensorFlow官网下载已经训练好的模型,或者自己训练一个模型并转换为tflite格式。 2. 加载模型 使用TensorFlow.js的`tf.lite.loadModel()`方法加载tflite模型文件。 ```javascript const model = await tf.lite.loadModel('model.tflite'); ``` 3. 加载图片 使用JavaScript的`Image`对象或者`HTMLCanvasElement`对象加载需要识别的图片。 ```javascript const image = new Image(); image.src = 'image.jpg'; await image.decode(); const canvas = document.createElement('canvas'); canvas.width = image.width; canvas.height = image.height; const context = canvas.getContext('2d'); context.drawImage(image, 0, 0, image.width, image.height); const imageData = context.getImageData(0, 0, image.width, image.height); ``` 4. 预处理图片数据 将图片数据转换为模型可以接受的格式。通常需要将像素值归一化到0到1之间,并且将图片数据转换为张量。 ```javascript const tensor = tf.browser.fromPixels(imageData) .resizeNearestNeighbor([224, 224]) .toFloat() .sub(255 / 2) .div(255 / 2) .expandDims(); ``` 5. 进行推理 调用模型的`predict()`方法进行推理,并且获取预测结果。 ```javascript const output = model.predict(tensor); const predictions = output.dataSync(); ``` 6. 处理预测结果 根据模型的输出,处理预测结果并进行展示。 ```javascript // 假设模型是一个分类模型,输出是一个长度为1000的数组,每个元素表示一个类别的概率 const topK = 10; // 取前10个概率最大的类别 const topIndices = tf.topk(output, topK).indices.dataSync(); const topProbabilities = tf.topk(output, topK).values.dataSync(); for (let i = 0; i < topIndices.length; i++) { console.log(`类别: ${topIndices[i]}, 概率: ${topProbabilities[i]}`); } ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值