大四考研狗一直拖到现在才开始做毕设也是心累(最主要的是还没考上,难受的一匹)。
首先感谢各位大佬提供的核心算法思想:Scale-recurrent Network for Deep Image Deblurring
然后加上大家比较感兴趣的数据集的问题:
链接:https://pan.baidu.com/s/1ehHWPNtEXCZVJQR5vtkZ5g 提取码:cy20
最后我的毕业设计打算基于大佬们的算法思想,使用Pytorch实现文中的尺度递归神经网络,并部署在自己的服务器上,利用Android客户端访问服务器获取去模糊的清晰图像。
现在我们来实现我们的网络搭建部分:
1.数据集的读取
此处我们使用的是和大佬们相同的GOPRO图片数据集,总共8.8G大家有需要的话可以私信我。
首先我们处理一下图片的读取问题。
数据集解压过后文件的内容是这样的:
这里的sharp/blur_gamma/blur三个文件存放清晰度不同的同一张照片,想来也是用来做网络训练的三种类别的图片,而在这之上的编号就显得并不重要了,所以我打算先将文件整合成sharp/blur_gamma/blur三组方便后面提取图片进行训练。
由于直接修改文件结构太过于费时费力,这里我们可以想办法筛选出各个类别的文件路径,把同一类的文件路径存放在一个txt文件中,训练时读取路径获取图片。
代码就直接贴出来了:
import os
import numpy as np
import pandas as pd
Simg_path = r"E:\工作\我的毕设\train"#训练集图片总路径
Limg_path = []#图片各自的具体路径
# Limg_path = np.array(Limg_path)
print(Simg_path)
for (root, dirs, files) in os.walk(Simg_path):
if files:
if files[0].split('.')[-1] == 'png':
for file in files:
#Limg_path = np.append(Limg_path,np.array(os.path.join(root,(file))))
#使用np矩阵,由于此处使用np矩阵消耗资源较大,
#我们可以选择先用list后转np.array
Limg_path.append(os.path.join(root,(file)))#使用list列表
else:
pass
#接着我们将文件按照blur/blur_gamma/sharp分为三组
Limg_path
Limg_path_blur = []
Limg_path_blurgamma = []
Limg_path_sharp = []
i = 0
for path in Limg_path:
if path.split("\\")[-2] == 'blur':
Limg_path_blur.append(path)
elif path.split("\\")[-2] == 'blur_gamma':
Limg_path_blurgamma.append(path)
elif path.split("\\")[-2] == 'sharp':
Limg_path_sharp.append(path)
else:
pass#我们只收集这三类图片路径数据
#将list输出
#blur
NPAimg_path = np.array(Limg_path_blur)
PDimg_path = pd.DataFrame(NPAimg_path)
PDimg_path.to_csv("trainImgblur.csv")
#blur_gamma
NPAimg_path = np.array(Limg_path_blurgamma)
PDimg_path = pd.DataFrame(NPAimg_path)
PDimg_path.to_csv("trainImgblur_gamma.csv")
#sharp
NPAimg_path = np.array(Limg_path_sharp)
PDimg_path = pd.DataFrame(NPAimg_path)
PDimg_path.to_csv("trainImgsharp.csv")
这里本来打算用np.savetxt()函数写入的,但是因为我的路径中有中文,提示无法写入非UTF-8文件。找了好一会也没个聪明点的方法,后面就还是改用pandas写入,果然就成了,这里还是吹一波pandas是真心好用啊。
###(* ̄︶ ̄)###
我们写入的文件大概就长这样
后面要用到的时候读取就好了。
接下来我们说说命令行运行的事。
2.python命令行运行函数和参数输入
这里我们就开始着手制作我们的神经网络了。
这里没太多好解释的,直接上代码好了:
import os
import argparse
def parse_args():
parser = argparse.ArgumentParser(description='deblur arguments')
parser.add_argument('--phase', type=str, default='test', help='determine whether train or test')
parser.add_argument('--datalist', type=str, default='./datalist_gopro.txt', help='training datalist')
parser.add_argument('--model', type=str, default='color', help='model type: [lstm | gray | color]')
parser.add_argument('--batch_size', help='training batch size', type=int, default=16)
parser.add_argument('--epoch', help='training epoch number', type=int, default=4000)
parser.add_argument('--lr', type=float, default=1e-4, dest='learning_rate', help='initial learning rate')
parser.add_argument('--gpu', dest='gpu_id', type=str, default='0', help='use gpu or cpu')
parser.add_argument('--height', type=int, default=720,
help='height for the tensorflow placeholder, should be multiples of 16')
parser.add_argument('--width', type=int, default=1280,
help='width for the tensorflow placeholder, should be multiple of 16 for 3 scales')
parser.add_argument('--input_path', type=str, default='./testing_set',
help='input path for testing images')
parser.add_argument('--output_path', type=str, default='./testing_res',
help='output path for testing images')
args = parser.parse_args()
return args
def main():
args = parse_args()
print(args)
if __name__ == '__main__':
main()
这里是参照大佬们项目写的输入参数,备注还是蛮齐全的,大家应该都能读懂。
这里说一下argparse是python自带的命令行参数解析包,可以用来方便地读取命令行参数。它的使用也比较简单。大家出门转转就能找到用法教程的。
大家看看运行结果:
这样我们就可以在运行网络时设置我们的各种参数了。(不过我后面会把参数全改成默认值,别问为啥,问就是懒)
###(* ̄︶ ̄)###