Slim Load VGG net

import sys
import os
sys.path.append("C:\\Users\\kuiw\\Desktop\\workspace\\python\\tf\\models\\research\\slim")

from datasets import dataset_utils
from nets import vgg
from preprocessing import vgg_preprocessing
from datasets import imagenet
import tensorflow as tf
import tensorflow.contrib.slim as slim
import cv2

import numpy as np
#MODEL_DIR = ".\\models\\vgg_19_2016_08_28\\"
#MDDEL_URL = "http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz"
MODEL_DIR = ".\\models\\vgg_16_2016_08_28\\"
MDDEL_URL = "http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz"

class SlimVGG:
    def __init__(self):
        if not tf.gfile.Exists(MODEL_DIR):
            print (MODEL_DIR, "is not exited")
            
            tf.gfile.MakeDirs(MODEL_DIR)
            dataset_utils.download_and_uncompress_tarball(MDDEL_URL, MODEL_DIR)

        else:
            print ("vgg model is ready!")
        
    def Evaluation(self):
        #vgg = tf.contrib.slim.nets.vgg
        img_size = vgg.vgg_19.default_image_size
        #imgPath="./test_data/school.jpg"
        imgPath="./test_data/tiger.jpeg"
        '''
        image = tf.read_file(imgPath)
        img = tf.image.decode_jpeg(image, channels=3)
        '''
        img = tf.placeholder(tf.float32, shape=[img_size,img_size,3], name='input_image')
        preprocessed_img = vgg_preprocessing.preprocess_image(img,img_size,img_size,is_training=False)
        preprocessed_img = tf.expand_dims(preprocessed_img, 0)
        with slim.arg_scope(vgg.vgg_arg_scope()):
            logits,endpoint = vgg.vgg_16(preprocessed_img,num_classes=1000,is_training=False)
        probs = tf.nn.softmax(logits)
        #load_model = slim.assign_from_checkpoint_fn(os.path.join(MODEL_DIR,'vgg_19.ckpt'), slim.get_model_variables('vgg_19'))
        load_model = slim.assign_from_checkpoint_fn(os.path.join(MODEL_DIR,'vgg_16.ckpt'),
                                                    slim.get_model_variables('vgg_16'))
        
        with tf.Session() as sess:
            load_model(sess)

            eva_img = cv2.imread(imgPath)
            eva_img = eva_img[:,:,::-1]
            eva_img = cv2.resize(eva_img,(img_size,img_size),interpolation=cv2.INTER_CUBIC)
            eva_img = eva_img.astype(np.float32)

            feed_dict = {img:eva_img}
            
            fc8_output,endout,probout = sess.run([logits,endpoint,probs],feed_dict = feed_dict)
            top5_prob = np.argsort(probout[0,:])[::-1][0:5]
            names = imagenet.create_readable_names_for_imagenet_labels()

            for i in top5_prob:
                print ("classify name:",names[i+1],",probability:",probout[0,i])
            # output other layer result
            #print (endout['vgg_19/conv5/conv5_3'])

if __name__ == "__main__":
    svgg = SlimVGG()
    svgg.Evaluation()
    

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值