Keras:关于fit_generator中,yeild到底接收的是什么类型的值?

本文是关于fit_generator函数中的generator属性中的yield内容。

什么?yield是什么?我也不知道yield呀,生成器呀什么的[img/眼神闪躲],戳我,我跟你简单喷喷fit_generator的好处,包含官方文档,可详细了解

fit_generator(generator=myGenerator(x_train, y_train, batch_size=50),
           	  steps_per_epoch=400,
           	  epochs=30,
           	  validation_data=myGenerator(x_valid, y_valid,  batch_size=50),
              validation_steps=50)

本人小白。搞了一下午。不停的出错,时而是传进去的数据类型不正确。时而是传进去了,在执行训练的时候,维度又出现bug。说实在的。虽然是小白,但是在np.array、list、tensor对象之间进行转换我还是会的。所以,我的主要问题就是:我不知道我要转换成什么类型,然后送给yield
弄了大半天。毫无进展,晚上吃完饭回来,把generator的内容全部推翻。尝试重新写!我就不信了。

盲猜

猜测一yield要的imagelabel肯定都是np.array类型(猜的)
首先:yield要的肯定是image和与之对应的label(假设每次取出50个数据。即batch_size=50)
其次:这个label标签必须已经做过了one-hot

好了,猜完了。。。。。。

那就开始做吧:
本文数据层次组织:

----- train
----------------------class_1
--------------------------------pic_1
--------------------------------pic_2
--------------------------------pic_3

--------------------------------pic_n
----------------------class_2
--------------------------------pic_1
--------------------------------pic_2
--------------------------------pic_3

--------------------------------pic_n
----------------------class_n
--------------------------------pic_1
--------------------------------pic_2
--------------------------------pic_3

--------------------------------pic_n
其他文件夹结构同上

步骤
(1)、获取所有的label标签和image [一一对应]

org_train_path = "D:/1/XiongAnDatasets/AID_1/img_all1/train"
org_valid_path = "D:/1/XiongAnDatasets/AID_1/img_all1/valid"
org_test_path = "D:/1/XiongAnDatasets/AID_1/img_all1/test"

# 需要的识别类型
classes = {'Bridge': 0, 'Meadow': 1, 'River': 2, 'Mountain': 3,
           'Beach': 4, 'Farmland': 5, 'Forest': 6}

# region 读取数据(顺序经此已经打乱)
def get_files(orig_picture):
    class_train = []
    label_train = []
    for index, name in enumerate(classes):
        print(index, name)
        class_path = orig_picture + '/' + name
        for pic in os.listdir(class_path):
            class_train.append(class_path + '/' + pic)
            label_train.append(index)
    temp = np.array([class_train, label_train])
    temp = temp.transpose()
    # shuffle the samples
    np.random.shuffle(temp)
    # after transpose, images is in dimension 0 and label in dimension 1
    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    return image_list, label_list
# endregion
def img2array(img_list):
    pre_x = []
    for i in range(len(img_list)):
        img = cv2.imread(img_list[i])
        img_resize = cv2.resize(img, (height, width))
        new_img = cv2.cvtColor(img_resize, cv2.COLOR_BGR2RGB)
        pre_x.append(new_img)  # input一张图片
    pre_x = np.array(pre_x) / 127.5 - 1.0
    return pre_x

(2)、调用上述代码并将得到的labelimageone-hotnp.array


from keras.utils import to_categorical

# region 划分并打乱数据
    x_train, y_train = get_files(org_train_path)
    # list------>np.array--------->one-hot(<-----<------<----看箭头-----------)
    y_train_one_hot = to_categorical(np.array(y_train))  # 引入头文件
    # list------>np.array(<--------<-----<------<----<-------看箭头-----------)
    x_train_new = img2array(x_train)
 # endregion

(3)、按照batch_size进行逐步送入generator

def myGenerator(X_img, Y_label, batch_size=50):
    # 传入的x_img的类型和Y_label是ndarray类型,
    # Y_label是已经转one-hot的了
    assert len(X_img) == len(Y_label)
    total_size = len(X_img)
    while 1:
        for i in range(int(total_size / batch_size)):
            yield X_img[i * batch_size:(i + 1) * batch_size], Y_label[i * batch_size:(i + 1) * batch_size]

    return myGenerator

(4)、训练调用

# 开始训练网络模型
history = model.fit_generator(generator=myGenerator(x_train_new,
													y_train_one_hot),
                              steps_per_epoch=400,
                              epochs=30,    
                              validation_data=myGenerator(x_valid_new,
                              							  y_valid_one_hot),
                              validation_steps=50)

(5)成。。。。。成。。。。。成功。。。。了。。。

总结:还是先声明:本人小白。以上代码虽然成功了。但是我并没有做到阻止内存爆炸的行为。执行时,还是会把全部数据一起加载到内存后,再进行训练。这篇文章的主题主要还是要去了解yield接收的数据的类型。如果有哪位大神在人群中多看了俺一眼,关于内存爆炸的问题,还请不吝赐教。感谢!

如有错误请指出。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值