自制:keras数据集制作与读取

可能我是个菜鸡吧,官方keras和tf不会用啊,无奈之下投入了手写生成器的邪教之内,io速率?别问,想哭,不过能跑就ok。

数据集制作也心累的一批,8G小内存如何降的住50+G大数据,哭哭,所以转投了H5邪教,直接把数据封装,自此告别原始数据的水深火热。

不废话,上干货。

H5数据生成器

# -*- coding: utf-8 -*-
import os
import cv2.cv2 as cv2
from SR.pmbloder import ReadBMPFile
import numpy as np
import h5py
def save_h5(h5f,data,target):
    shape_list=list(data.shape)
    if not h5f.__contains__(target):
        shape_list[0]=None
        dataset = h5f.create_dataset(target, data=data,maxshape=tuple(shape_list), chunks=True)
        return
    else:
        dataset = h5f[target]
    len_old=dataset.shape[0]
    len_new=len_old+data.shape[0]
    shape_list[0]=len_new
    dataset.resize(tuple(shape_list))
    dataset[len_old:len_new] = data



train_path = ''#训练集路径
val_path = ''#测试集路径
def load_image(image_path):
    merged = #数据获取操作,按需更改
    return merged

def data_generate(path,flag="train"):#sr问题示例,按需更改
    hrpath = path+'/HR'
    lrpath = path+'/LR'
    hrlist = os.listdir(hrpath)
    lrlist = os.listdir(lrpath)
    h5dir = path+'/'+flag+'.h5'
    index = list(range(len(hrlist)))
    h5f=h5py.File(h5dir)
    for i in index[:]:
        print(i)
        lrdata = load_image(lrpath+'/'+lrlist[i])
        hrdata = load_image(hrpath+'/'+hrlist[i])
        save_h5(h5f,data=np.array([lrdata]),target='lr')
        save_h5(h5f,data=np.array([hrdata]),target='hr')
    h5f.close()

if __name__ == '__main__':
    data_generate(val_path,"val")
    data_generate(train_path,"train")

训练步长计算:

def getdatastep(path,batch_size=8):
    hrpath = path+'/HR'
    hrlist = os.listdir(hrpath)
    steps = len(hrlist) // batch_size
    if steps*batch_size<len(hrlist):
        steps=steps+1
    return steps

数据生成器模板:

def train_generate(H5datapath,batchsize,shuffle=True):
    while 1:
        fid = h5py.File(H5datapath, 'r')
        trainnb = fid['data'].shape[0]
        c = [ i for i in range(trainnb//batchsize)]
        if shuffle:
            random.shuffle(c)
        t = trainnb//batchsize
        j = 0
        y = []
        x = []
        for i in c:
            x = np.array(fid['data'][i*batchsize:(i+1)*batchsize])
            y = np.array(fid['label'][i*batchsize:(i+1)*batchsize])
            yield(x,y)
        if t*batchsize < trainnb:
            x = np.array(fid['data'][t*batchsize:trainnb])
            y = np.array(fid['label'][t*batchsize:trainnb])
            yield(x,y)

 

  • 6
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值