使用tf-slim的inception_resnet_v2预训练模型进行图像分类

输入是jpg

代码:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 29 16:25:16 2017

@author: wayne
"""

'''
我们用的是tf1.2,最新的tf1.3地址是
https://github.com/tensorflow/models/tree/master/research/slim

http://geek.csdn.net/news/detail/126133
如何用TensorFlow和TF-Slim实现图像分类与分割

https://www.2cto.com/kf/201706/649266.html
【Tensorflow】辅助工具篇——tensorflow slim(TF-Slim)介绍

https://stackoverflow.com/questions/39582703/using-pre-trained-inception-resnet-v2-with-tensorflow
The Inception networks expect the input image to have color channels scaled from [-1, 1]. As seen here.
You could either use the existing preprocessing, or in your example just scale the images yourself: im = 2*(im/255.0)-1.0 before feeding them to the network.
Without scaling the input [0-255] is much larger than the network expects and the biases all work to very strongly predict category 918 (comic books).
'''

import tensorflow as tf
slim = tf.contrib.slim
from PIL import Image
from inception_resnet_v2 import *
import numpy as np
import inception_preprocessing
import matplotlib.pyplot as plt
import imagenet  #注意需要用最新版tf中的对应文件,否则http地址是不对的

tf.reset_default_graph()

checkpoint_file = 'inception_resnet_v2_2016_08_30.ckpt'
image = tf.image.decode_jpeg(tf.read_file('dog.jpeg'), channels=3) #['dog.jpg', 'panda.jpg']

image_size = inception_resnet_v2.default_image_size #  299

'''这个函数做了裁剪,缩放和归一化等'''
processed_image = inception_preprocessing.preprocess_image(image, 
                                                        image_size, 
                                                        image_size,
                                                        is_training=False,)
processed_images  = tf.expand_dims(processed_image, 0)

'''Creates the Inception Resnet V2 model.'''
arg_scope = inception_resnet_v2_arg_scope()
with slim.arg_scope(arg_scope):
  logits, end_points = inception_resnet_v2(processed_images, is_training=False)   

probabilities = tf.nn.softmax(logits)

saver = tf.train.Saver()


with tf.Session() as sess:
    saver.restore(sess, checkpoint_file)

    #predict_values, logit_values = sess.run([end_points['Predictions'], logits])
    image2, network_inputs, probabilities2 = sess.run([image,
                                                       processed_images,
                                                       probabilities])

    print(network_inputs.shape)
    print(probabilities2.shape)
    probabilities2 = probabilities2[0,:]
    sorted_inds = [i[0] for i in sorted(enumerate(-probabilities2),
                                        key=lambda x:x[1])]    


# 显示下载的图片
plt.figure()
plt.
  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值