tensorflow中多分类数据的导入和数据增强扩充的实现

一、多分类数据的导入

在tensorflow的图像预处理方法中有如下的这样一个函数,设置directory文件路径,在这个文件路径下该函数会将其每个子文件夹作为一个类,按照这个函数的读取顺序对类进行编码。

flow_from_directory(
    directory, target_size=(256, 256), color_mode='rgb', classes=None,
    class_mode='categorical', batch_size=32, shuffle=True, seed=None,
    save_to_dir=None, save_prefix='', save_format='png',
    follow_links=False, subset=None, interpolation='nearest'
)

但是经过实验发现他这个读取顺序很难获取,从而无法准确得到类名所对应的编码,总结这个方法只对二分类有效,对于多分类这个函数就无法使用了。
下面自己编写了这样一个方法实现了多分类的需求:

#从文件夹导入图片和标签,顺便缩小图片
def get_files(filename):
    class_train = []
    label_train = []
    for train_class in os.listdir(filename):
        for pic in os.listdir(os.path.join(filename,train_class)):
            img=cv2.imread(filename+'/'+train_class+'/'+pic)
            class_train.append(img)
            label_train.append(train_class)
    image_list=np.array(class_train)
    label_list=np.array(label_train)
    return image_list,label_list

#将标签进行热编码
def one_hot_bm(Y):
    label_list=[]
    for i in range(len(Y)):
        a=np.array([0]*330)
        a[int(Y[i])-1]=1
        label_list.append(a)
    return label_list

(这里假设类的标签是1到330)
这样就可以得到图片数据和对应的类标签独热编码。
后面就可以使用下面这个函数来生成图片迭代器然后直接用于模型训练:

tf.keras.preprocessing.image.NumpyArrayIterator(
    x, y, image_data_generator, batch_size=32, shuffle=False, sample_weight=None,
    seed=None, data_format=None, save_to_dir=None, save_prefix='',
    save_format='png', subset=None, dtype=None
)

二、数据增强扩充的实现

在tensorflow中有这样的一个图像增强器,可以实现大多数的一种图像增强,但是它只是把原来图像进行了一个替代,并没有实现数据的扩充。

tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False, samplewise_center=False,
    featurewise_std_normalization=False, samplewise_std_normalization=False,
    zca_whitening=False, zca_epsilon=1e-06, rotation_range=0, width_shift_range=0.0,
    height_shift_range=0.0, brightness_range=None, shear_range=0.0, zoom_range=0.0,
    channel_shift_range=0.0, fill_mode='nearest', cval=0.0,
    horizontal_flip=False, vertical_flip=False, rescale=None,
    preprocessing_function=None, data_format=None, validation_split=0.0, dtype=None
)

但是搭配如下的一个方法只要不设置终止它会一直使用这个图像增强器生成图片,结合使用yield操作就可以实现自己想要的一个数据扩充。

flow(
    x, y=None, batch_size=32, shuffle=True, sample_weight=None, seed=None,
    save_to_dir=None, save_prefix='', save_format='png',
    subset=None
)

数据扩充:

def gent(bes):
    if bes>0:
        t= image_generator.flow(image_list, label_list, batch_size=32)
        bes-=1
        for i in t:
            yield i

传入参数bes调用gent(bes)就可以实现数据的扩充,扩充的是原来的bes倍,返回的就是一个数据迭代器可以直接传入模型进行训练。

  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值