Maybe I'm just a scrub, but I never managed to get the official Keras/tf input pipelines working, so in desperation I joined the cult of hand-written generators. I/O throughput? Don't ask, it makes me want to cry, but it runs, and that's good enough.
Building the dataset was exhausting too: how is 8 GB of RAM supposed to tame 50+ GB of data? So I converted to the H5 cult and packed everything into a single HDF5 file, bidding farewell to the misery of loose raw files.
Enough rambling; on to the good stuff.
H5 dataset builder:
# -*- coding: utf-8 -*-
import os

import cv2
import h5py
import numpy as np
# from SR.pmbloder import ReadBMPFile  # project-specific BMP loader; unused in this snippet


def save_h5(h5f, data, target):
    """Append `data` along axis 0 of the dataset named `target`, creating it on first write."""
    shape_list = list(data.shape)
    if target not in h5f:
        # First write: create a chunked dataset that is resizable along axis 0.
        shape_list[0] = None
        h5f.create_dataset(target, data=data, maxshape=tuple(shape_list), chunks=True)
    else:
        # Subsequent writes: grow the dataset, then copy the new rows into the tail.
        dataset = h5f[target]
        len_old = dataset.shape[0]
        len_new = len_old + data.shape[0]
        shape_list[0] = len_new
        dataset.resize(tuple(shape_list))
        dataset[len_old:len_new] = data


train_path = ''  # training set path
val_path = ''    # validation set path


def load_image(image_path):
    merged = cv2.imread(image_path)  # placeholder loader; replace with your own preprocessing as needed
    return merged


def data_generate(path, flag="train"):  # super-resolution example; change as needed
    hrpath = path + '/HR'
    lrpath = path + '/LR'
    # Sort both listings so HR/LR pairs stay aligned by filename.
    hrlist = sorted(os.listdir(hrpath))
    lrlist = sorted(os.listdir(lrpath))
    h5dir = path + '/' + flag + '.h5'
    h5f = h5py.File(h5dir, 'a')  # note: re-running appends; delete the file first for a clean build
    for i in range(len(hrlist)):
        print(i)
        lrdata = load_image(lrpath + '/' + lrlist[i])
        hrdata = load_image(hrpath + '/' + hrlist[i])
        save_h5(h5f, data=np.array([lrdata]), target='lr')
        save_h5(h5f, data=np.array([hrdata]), target='hr')
    h5f.close()


if __name__ == '__main__':
    data_generate(val_path, "val")
    data_generate(train_path, "train")
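To sanity-check the packed file, open it once and print each dataset's shape and dtype. A minimal sketch, assuming the builder above already wrote train.h5 with the 'lr'/'hr' keys (the path is a placeholder):

import h5py

# Placeholder path: point this at the train.h5 / val.h5 produced above.
with h5py.File('train.h5', 'r') as h5f:
    for key in h5f.keys():
        print(key, h5f[key].shape, h5f[key].dtype)  # first dim should equal the sample count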
Computing steps per epoch:
def getdatastep(path, batch_size=8):
    # Steps per epoch = ceil(num_samples / batch_size), counted from the image directory.
    hrpath = path + '/HR'
    hrlist = os.listdir(hrpath)
    steps = len(hrlist) // batch_size
    if steps * batch_size < len(hrlist):
        steps = steps + 1
    return steps
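Since the sample count now lives inside the H5 file, the same step count can also be read from there, which keeps it consistent with whatever the generator actually iterates over. A sketch under that assumption; geth5step is a hypothetical helper and 'hr' is just the key used by the builder above:

import h5py

def geth5step(h5path, batch_size=8, key='hr'):
    # ceil(num_samples / batch_size), counted from the packed file itself
    with h5py.File(h5path, 'r') as fid:
        n = fid[key].shape[0]
    return (n + batch_size - 1) // batch_size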
Data generator template:
import random

def train_generate(H5datapath, batchsize, shuffle=True):
    # Dataset keys 'data'/'label' are generic; for files written by the builder above they would be 'lr'/'hr'.
    fid = h5py.File(H5datapath, 'r')  # open once and keep the handle for the generator's lifetime
    trainnb = fid['data'].shape[0]
    t = trainnb // batchsize  # number of full batches
    while True:  # Keras expects an endless generator
        c = list(range(t))
        if shuffle:
            random.shuffle(c)  # shuffle batch order each epoch (slices stay contiguous, which HDF5 reads fast)
        for i in c:
            x = np.array(fid['data'][i * batchsize:(i + 1) * batchsize])
            y = np.array(fid['label'][i * batchsize:(i + 1) * batchsize])
            yield x, y
        if t * batchsize < trainnb:
            # trailing partial batch
            x = np.array(fid['data'][t * batchsize:trainnb])
            y = np.array(fid['label'][t * batchsize:trainnb])
            yield x, y
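Wiring it all together, the generator plugs straight into Keras training. A minimal sketch, assuming a compiled model and the files/helpers from the snippets above (all paths and the batch size are placeholders; the keys read by the generator must match what the builder wrote):

batch_size = 8

model.fit(
    train_generate(train_path + '/train.h5', batch_size),
    steps_per_epoch=getdatastep(train_path, batch_size),
    validation_data=train_generate(val_path + '/val.h5', batch_size, shuffle=False),
    validation_steps=getdatastep(val_path, batch_size),
    epochs=50,
)

On older Keras versions the equivalent call is model.fit_generator with the same arguments.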