from tensorflow.python.keras.applications.vgg16 import VGG16,preprocess_input,decode_predictions
from tensorflow.python.keras.preprocessing.image import load_img,img_to_array
def predict():
model = VGG16()
print(model.summary())
#预测一张图片的类别
#加载图片并输入到模型当中
#(224,224)是VGG的输入要求
image = load_img("./tiger.png",target_size=(224,224))
image = img_to_array(image)
#输入到卷积神经网络当中,需要四维结构
image = image.reshape((1,image.shape[0],image.shape[1],image.shape[2]))
print(image.shape)
#预测之前做图片的数据处理,归一化处理等
image = preprocess_input(image)
y_predictions = model.predict(image)
label = decode_predictions(y_predictions)
print(label)
if __name__ == '__main__':
predict()