keras 导入本地图片测试集并预测结果

需要的库

from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import numpy as np
import os
import cv2
from PIL import Image
model = tf.keras.models.load_model('sand_model2') #加载已有模型

方法1 (图片较多,按照字典形式存放)
在这里插入图片描述
ImageDataGenerator方法

data_dir='C:/Users/HP/Desktop/tf_data/sand_train' #输入自己的目录,此目录下应有多个子文件夹
test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
        data_dir,
        shuffle=False,#测试集不需要打乱
        target_size=(256,256), #图片大小,默认256,256
        batch_size=5, #可任意设置
        class_mode='categorical')
  
#用生成器方式预测
predictions=model.predict_generator(test_generator)
#显示结果
result=[]
for iii in range(predictions.shape[0]):
    result.append(np.argmax(predictions[iii]))
print('result',result)

方法2(图片较少,按照字典形式存放)
直接读入,使用Image.open

data_dir = 'C:/Users/HP/Desktop/tf_data/sand_train2'
def read_data(data_dir):
    datas = [] #存储图片数据
    labelnames=[] #存储标签名(文件夹名)
    labels = [] #存储标签
    fpaths = [] #存储每个图片路径
    co=0 #标签,按文件夹顺序分别设为0,1,2......
    for fname in os.listdir(data_dir):#遍历每个文件夹
        fpath = data_dir+'/'+fname
        labelnames.append(str(fname))#文件夹名即为标签名
        for ffname in os.listdir(fpath):#遍历一个文件夹下每张图片
            try:
                image = Image.open(os.path.join(fpath, ffname))
                im = np.array(image)
                if len(im.shape)==2:
                    image=Image.merge('RGB', (image, image, image))#如果是灰度图,扩展为三通道
                data = np.array(image) / 255.0
                data.astype('float32')
                datas.append(data)
                label = co
                labels.append(label)
                everypath = fpath + '/' + ffname
                fpaths.append(everypath)#记下每张图的路径
            except:
                continue
        co=co+1 #标签+1
    datas = np.array(datas)
    labels = np.array(labels)

    print("shape of datas: {}\tkind of labels: {}".format(datas.shape,labels.shape))
    return fpaths, datas, labels,labelnames

#读取
fpaths, test_images, test_labels,labelnames = read_data(data_dir)

#预测
predictions = model.predict(test_images)

#显示结果
result=[]
for iii in range(predictions.shape[0]):
    result.append(np.argmax(predictions[iii]))
print('result',result)

方法3(图片较少,没有标签,存在一个文件夹里)
在这里插入图片描述

使用cv2.imread

#直接读取图片
def get_image(data_dir, img_cols, img_rows, color_type=3,normalize=True):
        imgs = []
        for fname in os.listdir(data_dir):
            try:
                if color_type == 1:
                    img = cv2.imread(os.path.join(data_dir, fname), 0)

                elif color_type == 3:
                    img = cv2.imread(os.path.join(data_dir, fname))

                    # Reduce size
                resized = cv2.resize(img, (img_cols, img_rows))

                if normalize:
                    resized = resized.astype('float32')
                    resized /= 255
                imgs.append(resized)
            except:
                    continue
        imgs=np.array(imgs)
        print("shape of datas: {}".format(imgs.shape))
        return  imgs

#读取10张测试图片
imatest =get_image('C:/Users/HP/Desktop/tf_data/test/test',256,256)

#预测
predictions = model.predict(imatest)

#显示结果
result=[]
for iii in range(predictions.shape[0]):
    result.append(np.argmax(predictions[iii]))
print('result',result)

方法4(已有所有图片路径和标签)
此处data_dir 为一个列表,存放所有图片路径

def get_image_from_paths(data_dir, img_cols, img_rows, color_type=3, normalize=True):
    imgs = []
    for fname in data_dir:
        try:
            if color_type == 1:
                img = cv2.imread(fname, 0)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            elif color_type == 3:
                img = cv2.imread(fname)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                # Reduce size
            resized = cv2.resize(img, (img_cols, img_rows))

            if normalize:
                resized = resized.astype('float32')
                resized /= 255

            imgs.append(resized)
        except:
            continue
    imgs = np.array(imgs)
    print("shape of datas: {}".format(imgs.shape))
    return imgs
imatest =get_image_from_paths(your_paths,256,256) #读取
#后续与上一方法相同

输出形式

#结果
result [0, 0, 0, 0, 2, 0, 0, 0, 0, 0]
  • 0
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值