import sys
import os
sys.path.append("C:\\Users\\kuiw\\Desktop\\workspace\\python\\tf\\models\\research\\slim")
from datasets import dataset_utils
from nets import vgg
from preprocessing import vgg_preprocessing
from datasets import imagenet
import tensorflow as tf
import tensorflow.contrib.slim as slim
import cv2
import numpy as np
# Pretrained VGG-16 checkpoint: local directory it is unpacked into, and the
# tarball it is downloaded from. For VGG-19, swap "vgg_16_2016_08_28" for
# "vgg_19_2016_08_28" in both constants and use the vgg_19 network/checkpoint.
# os.path.join keeps the directory portable (a literal ".\\models\\...\\" is a
# single odd directory name on POSIX systems).
MODEL_DIR = os.path.join(".", "models", "vgg_16_2016_08_28")
MDDEL_URL = "http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz"
class SlimVGG:
    """Classify images with the pretrained TF-Slim VGG-16 ImageNet model.

    Constructing an instance makes sure the VGG-16 checkpoint is present in
    ``MODEL_DIR``, downloading and unpacking ``MDDEL_URL`` if it is not.
    """

    # Checkpoint file expected inside MODEL_DIR after unpacking the tarball.
    CKPT_NAME = "vgg_16.ckpt"

    def __init__(self):
        """Download the VGG-16 checkpoint into MODEL_DIR unless it already exists."""
        ckpt_path = os.path.join(MODEL_DIR, self.CKPT_NAME)
        # Check the checkpoint file itself rather than just the directory:
        # an interrupted download can leave MODEL_DIR without a checkpoint.
        if not tf.gfile.Exists(ckpt_path):
            print (ckpt_path, "does not exist, downloading", MDDEL_URL)
            tf.gfile.MakeDirs(MODEL_DIR)
            dataset_utils.download_and_uncompress_tarball(MDDEL_URL, MODEL_DIR)
        else:
            print ("vgg model is ready!")

    @staticmethod
    def _read_rgb_image(img_path, img_size):
        """Load ``img_path`` as a float32 RGB array resized to (img_size, img_size, 3).

        Raises:
            IOError: if OpenCV cannot read the file.
        """
        bgr = cv2.imread(img_path)
        if bgr is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # instead of raising, which would otherwise fail on the slice below.
            raise IOError("cannot read image: %s" % img_path)
        rgb = bgr[:, :, ::-1]  # OpenCV decodes to BGR order; reverse to RGB
        rgb = cv2.resize(rgb, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
        return rgb.astype(np.float32)

    def Evaluation(self, img_path="./test_data/tiger.jpeg", top_k=5):
        """Classify one image with VGG-16 and print the top-k ImageNet labels.

        Args:
            img_path: path of the image file to classify.
            top_k: number of highest-probability classes to report.

        Returns:
            A list of ``(label_name, probability)`` tuples, most probable first.

        Raises:
            IOError: if the image at ``img_path`` cannot be read.
        """
        img_size = vgg.vgg_16.default_image_size
        # Read the image first so a bad path fails before the model is built.
        eva_img = self._read_rgb_image(img_path, img_size)

        # Build in a private graph: re-creating the 'vgg_16' variables in the
        # default graph on a second call would fail.
        with tf.Graph().as_default():
            img = tf.placeholder(tf.float32, shape=[img_size, img_size, 3],
                                 name='input_image')
            preprocessed_img = vgg_preprocessing.preprocess_image(
                img, img_size, img_size, is_training=False)
            preprocessed_img = tf.expand_dims(preprocessed_img, 0)
            with slim.arg_scope(vgg.vgg_arg_scope()):
                # The second return value is a dict of intermediate layer
                # outputs keyed by layer name; only the class scores are used.
                logits, _ = vgg.vgg_16(preprocessed_img, num_classes=1000,
                                       is_training=False)
            probs = tf.nn.softmax(logits)
            load_model = slim.assign_from_checkpoint_fn(
                os.path.join(MODEL_DIR, self.CKPT_NAME),
                slim.get_model_variables('vgg_16'))
            with tf.Session() as sess:
                load_model(sess)
                # Fetch only the probabilities; fetching every endpoint would
                # copy all intermediate activations back for nothing.
                probout = sess.run(probs, feed_dict={img: eva_img})

        class_probs = probout[0, :]
        top_ids = np.argsort(class_probs)[::-1][:top_k]
        # The label table is indexed one past the 1000 model outputs
        # (presumably index 0 is a 'background' entry — verify), hence i + 1.
        names = imagenet.create_readable_names_for_imagenet_labels()
        results = []
        for i in top_ids:
            print ("classify name:", names[i + 1], ",probability:", class_probs[i])
            results.append((names[i + 1], class_probs[i]))
        return results
if __name__ == "__main__":
    # Script entry: ensure the checkpoint is available, then classify the sample image.
    classifier = SlimVGG()
    classifier.Evaluation()
# Slim Load VGG net
# (Latest recommended article published 2023-06-11 17:06:13)