项目简介
在超分辨领域,下载到的数据集只有高分辨图像,低分辨图像需要根据高分辨率图像通过相应的下采样方式获取。本项目使用python脚本实现对高分辨图像的批量下采样。
实现思路
1、获取用户输入参数。包括高分辨图像路径、下采样因子等。
2、获取文件列表,也就是说我们需要知道用户传入的路径中有多少图片,然后记录图片的位置。
3、遍历整理好的文件列表,逐张打开图片。
4、先将原图作为高分辨率图像复制到相应的文件夹。
5、对图片进行下采样。
6、将下采样得到的图片作为低分辨率图像存放在相应文件夹。
完整代码
import cv2
import os
from tqdm import *
from pathlib import Path
import time
import argparse
# 整理文件
def countFile(dir):
tmp = 0 # 用于记录图片数量
f = open(os.path.join(os.getcwd(), "filelist.txt"), mode='a+') # 在当前目录创建文本记录文件位置
for item in os.listdir(dir): # 遍历文件夹
if os.path.isfile(os.path.join(dir, item)): # 如果该项是文件
tmp += 1 # 图片数量加一
f.write(os.path.join(dir, item) + '\n') # 记录文件位置
else:
f.close()
tmp += countFile(os.path.join(dir, item))
f.close()
return tmp
def main(args):
# 1、相关参数设置(已在args中完成)
# 2、获取文件数量
filenum = 0
if args.fileListFlag:
if os.path.exists(os.path.join(os.getcwd(), "filelist.txt")):
os.remove(os.path.join(os.getcwd(), "filelist.txt"))
print("正在整理文件...")
filenum = countFile(args.originalPath) # 返回图片的张数并生成文件列表
print('给定路径中的文件数量:' , filenum)
# 3、遍历打开每幅图片
if filenum != 0:
f = open(os.path.join(os.getcwd(), "filelist.txt"), mode='r+')
pbar = tqdm(total=filenum)
for originalPicPath in f.readlines():
if not originalPicPath=='\n':
originalPicPath = originalPicPath.replace('\n','')
originalPic = cv2.imread(originalPicPath)
# 4、复制原图片到指定文件夹
PicName_HR = os.path.join(args.savePath, "HR", originalPicPath[-18:])
# 如果文件夹不存在就创建文件夹
if not os.path.exists(PicName_HR[:-7]):
os.makedirs(PicName_HR[:-7])
cv2.imwrite(PicName_HR,originalPic)
# 5、对打开的图片进行下采样
rows, cols, channels = originalPic.shape
BicubicPic = cv2.resize(originalPic, (int(cols/args.scale), int(rows/args.scale)), interpolation=cv2.INTER_CUBIC)
# 6、将下采样图片保存到指定文件夹
PicName_LR = os.path.join(args.savePath, "LR_bicubic", originalPicPath[-18:])
# 如果文件夹不存在就创建文件夹
if not os.path.exists(PicName_LR[:-7]):
os.makedirs(PicName_LR[:-7])
cv2.imwrite(PicName_LR,BicubicPic)
pbar.update(1) # 更新进度条
pbar.close() # 关闭进度条
f.close() # 关闭文件
else:
print("该文件夹中没有图片")
if __name__ == "__main__":
# 获取用户传入的参数
description = '搜索给定文件夹中的图片并使用插值方法下采样'
parser = argparse.ArgumentParser(description=description)
parser.add_argument('--fileListFlag', type=bool, default=True, help='是否新建文件列表')
parser.add_argument('--scale', type=int, default=4, help='下采样因子,默认为4')
parser.add_argument('--originalPath', type=str, default=r"D:\BaiduNetdiskDownload\vimeo_septuplet\vimeo_septuplet\sequences", help='图片文件路径')
parser.add_argument('--savePath', type=str, default=os.getcwd(), help='整理图片的保存路径')
args = parser.parse_args()
main(args)
实现效果
下载到的关于视频超分辨的数据集:
vimeo数据集下载地址
整理后的数据集: