最近需要实现超分辨率问题,但EDSR模型不适合自己的数据集,故重新用keras写了一遍,以下是遇到的坑:
1、批量读取图片
因为内存不够的问题,尝试了h5py,train_on_batch,以及fit_generator三种不同的加载大数据集的方式,最后感觉还是fit_generator好用,生成器代码如下(训练和标签均为图片):
def generator(index_list,path,batch_size):
list_x=[]
list_y=[]
count=0
i = 0
while 1:
f = index_list[i%len(index_list)]
img_path1 = path + 'low/' + f
img_path2 = path + 'high/' + f
img_data = image.load_img(img_path1)
img_label = image.load_img(img_path2)
img_array = image.img_to_array(img_data)
img_array2 = image.img_to_array(img_label)
list_x.append(img_array)
list_y.append(img_array2)
count+=1
i = i+1
if count>=batch_size: #数据记录达到batch_size才返回