Keras 文件读取

import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import os
import PIL
import pathlib
import math
import random
import numpy as np
import shutil
import PIL
# 划出测试图像
def div_train_test(data_dir, test_dir='test', test_ratio=0.3):
    """Move a random share of every class's files into a test folder.

    ``data_dir`` is expected to contain one sub-directory per class. For each
    class, ``ceil(file_count * test_ratio)`` randomly chosen files are moved
    to ``test_dir/<class name>`` (created on demand), so they no longer exist
    under ``data_dir`` afterwards.

    Files are moved with ``shutil.move`` rather than being decoded and
    re-saved, so their bytes are preserved and any file type is handled.

    Args:
        data_dir: Directory holding one sub-directory per class.
        test_dir: Destination root for the held-out files. Defaults to
            ``'test'``, which is resolved against the current working
            directory.
        test_ratio: Fraction of each class to move, within ``[0, 1]``.
            Defaults to 0.3.

    Raises:
        ValueError: If ``test_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= test_ratio <= 1:
        raise ValueError(
            'test_ratio must be within [0, 1], got {!r}'.format(test_ratio))

    data_dir = pathlib.Path(data_dir)
    test_dir = pathlib.Path(test_dir)

    # Every entry located one level below a top-level entry of data_dir.
    image_count = len(list(data_dir.glob('*/*')))
    print('一共有有{}张图像'.format(image_count))

    # Only sub-directories are classes; stray files next to them are ignored.
    class_dirs = sorted(
        entry for entry in data_dir.iterdir() if entry.is_dir())
    name_list = [class_dir.name for class_dir in class_dirs]
    print('有如下类别{}共{}类'.format(name_list, len(name_list)))

    test_dir.mkdir(parents=True, exist_ok=True)

    for class_dir in class_dirs:
        target_dir = test_dir / class_dir.name
        target_dir.mkdir(exist_ok=True)

        # List regular files explicitly instead of relying on a glob pattern.
        files = sorted(
            entry for entry in class_dir.iterdir() if entry.is_file())
        num = len(files)
        print(num)

        test_num = math.ceil(num * test_ratio)
        chosen = random.sample(files, test_num)
        print(len(chosen))

        print('正在从{}拷贝'.format(class_dir))
        for src in chosen:
            # pathlib joins keep this portable across path separators.
            shutil.move(str(src), str(target_dir / src.name))


# --------------------------------------------------------------------------




def datapreprocess(data_dir, batch_size, img_height, img_width,
                   data_augment=False):
    """Build normalised grayscale ``tf.data`` datasets from an image folder.

    The split is chosen from the last component of ``data_dir``: a folder
    named ``train`` yields a training/validation pair (80/20, seed 111), a
    folder named ``test`` yields a single dataset. Before loading, if
    ``./test`` contains no ``<class>/<file>`` entries, ``div_train_test`` is
    called on ``'train'`` to create the test split.

    One sample image is displayed with matplotlib and its decoded shape is
    printed; for the training split a grid of up to nine images from the
    first batch is displayed as well. Both ``plt.show()`` calls block when
    an interactive backend is used.

    Args:
        data_dir: Image root with one sub-directory per class; its final
            path component must be ``train`` or ``test``.
        batch_size: Batch size of the returned datasets.
        img_height: Height images are resized to.
        img_width: Width images are resized to.
        data_augment: When true and the split is ``train``, an augmentation
            model (random rotation and zoom) is returned as a third element.
            Defaults to False.

    Returns:
        ``(train_ds, val_ds, data_augmentation)`` for the train split with
        augmentation, ``(train_ds, val_ds)`` for the train split without it,
        ``test_ds`` for the test split, and ``None`` for any other folder
        name.

    Raises:
        ValueError: If no files are found under ``data_dir/<class>/``.
    """
    # A bare ``import PIL`` does not guarantee the Image submodule is loaded.
    from PIL import Image

    # ------------------------------------------------------------------
    # Create the test split first if ./test has no class/file entries yet.
    test_root = pathlib.Path(os.getcwd()) / 'test'
    if not any(test_root.glob('*/*')):
        div_train_test('train')
    print(test_root)

    # ------------------------------------------------------------------
    # Inspect the folder: file count, class names, one sample image.
    data_dir = pathlib.Path(data_dir)
    split = data_dir.name
    image_paths = sorted(
        path for path in data_dir.glob('*/*') if path.is_file())
    print('一共有有{}张图像'.format(len(image_paths)))

    name_list = sorted(
        entry.name for entry in data_dir.iterdir() if entry.is_dir())
    print('有如下类别{}共{}类'.format(name_list, len(name_list)))

    if not image_paths:
        raise ValueError('no image files found under {}'.format(data_dir))

    sample_path = image_paths[0]
    with Image.open(sample_path) as preview:
        plt.imshow(preview)
        plt.colorbar()
        plt.show()

    # Print the decoded shape of the same sample image.
    sample_tensor = tf.image.decode_image(tf.io.read_file(str(sample_path)))
    print(sample_tensor.shape)

    # ------------------------------------------------------------------
    # Shared pipeline tail: cache, optional shuffle, prefetch, rescale.
    autotune = tf.data.experimental.AUTOTUNE
    normalization_layer = keras.layers.experimental.preprocessing.Rescaling(
        1. / 255)

    def _finalize(dataset, shuffle=False):
        dataset = dataset.cache()
        if shuffle:
            dataset = dataset.shuffle(1000)
        dataset = dataset.prefetch(buffer_size=autotune)
        return dataset.map(lambda x, y: (normalization_layer(x), y))

    # color_mode defaults to 'rgb'; these images are loaded as grayscale.
    load_kwargs = dict(
        image_size=[img_height, img_width],
        color_mode='grayscale',
        batch_size=batch_size,
    )

    if split == 'train':
        train_ds = keras.preprocessing.image_dataset_from_directory(
            str(data_dir), validation_split=0.2, subset='training',
            seed=111, **load_kwargs)
        print(train_ds)
        val_ds = keras.preprocessing.image_dataset_from_directory(
            str(data_dir), validation_split=0.2, subset='validation',
            seed=111, **load_kwargs)
        print(val_ds)

        # Labels index the dataset's own class order, so titles must come
        # from it rather than from a directory listing.
        class_names = train_ds.class_names
        print(class_names)

        plt.figure(figsize=(10, 10))
        for images, labels in train_ds.take(1):
            # The first batch may hold fewer than nine images.
            for i in range(min(9, int(images.shape[0]))):
                plt.subplot(3, 3, i + 1)
                plt.imshow(images[i].numpy().astype("uint8"))
                plt.title(class_names[int(labels[i])])
                plt.axis("off")
            print(images.shape)
            print(labels.shape)
        plt.show()

        train_ds = _finalize(train_ds, shuffle=True)
        val_ds = _finalize(val_ds)

        print('当前data_dir is ', data_dir)
        if data_augment:
            # RandomFlip is left out on purpose: the original author hit a
            # numpy version error with it.
            data_augmentation = keras.Sequential([
                keras.layers.experimental.preprocessing.RandomRotation(0.1),
                keras.layers.experimental.preprocessing.RandomZoom(0.1),
            ])
            return train_ds, val_ds, data_augmentation
        return train_ds, val_ds

    if split == 'test':
        test_ds = keras.preprocessing.image_dataset_from_directory(
            str(data_dir), **load_kwargs)
        print(test_ds.class_names)
        test_ds = _finalize(test_ds)
        print('当前data_dir is ', data_dir)
        return test_ds

    # Folder names other than 'train' / 'test' are not handled.
    return None




# Build the training/validation datasets (with augmentation) from ./train.
train, val, augmentation = datapreprocess(
    data_dir='train',
    batch_size=32,
    img_height=224,
    img_width=224,
    data_augment=True,
)
# Build the held-out test dataset from ./test.
test = datapreprocess(
    data_dir='test',
    batch_size=32,
    img_height=224,
    img_width=224,
    data_augment=False,
)
for datasets in ((train, val), (test,)):
    print(*datasets)





  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值