keras使用Sequence类调用大规模数据集进行训练

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。
下面是我所使用的代码

class SequenceData(Sequence):
    def __init__(self, path, batch_size=32):
        self.path = path
        self.batch_size = batch_size
        f = open(path)
        self.datas = f.readlines()
        self.L = len(self.datas)
        self.index = random.sample(range(self.L), self.L)
    #返回长度,通过len(<你的实例>)调用
    def __len__(self):
        return self.L - self.batch_size
    #即通过索引获取a[0],a[1]这种
    def __getitem__(self, idx):
        batch_indexs = self.index[idx:(idx+self.batch_size)]
        batch_datas = [self.datas[k] for k in batch_indexs]
        img1s,img2s,audios,labels = self.data_generation(batch_datas)
        return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

    def data_generation(self, batch_datas):
        #预处理操作
        return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练
这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。

D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)), 
                    epochs=2, workers=20, #callbacks=[checkpoint],
                    use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))  

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

cccccc1212

这是c币不是人民币,不要充值

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值