背景介绍
这两天学习attention机制,看到了这篇文章,就看了看文章和他的代码,代码有些地方没有注明,使用起来难度较大,故在这里写一下这篇文章的代码复现。话不多说,开始展示。
论文信息
代码地址:github地址
论文地址:arxiv地址
ECSSD数据集:ECSSD dataset 地址
参考链接:github地址
参考链接:github地址
实现难点
- 缺少两个文件(可能是我没看懂作者的代码,比如cvpr2019——PFA里 get_list.py 可能是产生 train_pair.txt 的,但我确实没操作出来;
- 部分语句有点问题;
- python3和python2 的print语句差别
- 数据集不完整
PS(我解决了这些我认为是问题的问题,可能作者在文档中有相关文件,可能是我没看到,我只是按照我的方法实现了结果复现)
问题1 缺少两个文件
从github上下载代码,点击Download ZIP
解压出来
文件目录大致如下
这里呢,相对于下载下来的目录是多了两个文件的,换句话说,下载下来的代码中是没有以下两个文件的,用下面的代码新建一个generate_train_file.py文件即可。
- generate_train_file.py # 用于生成train_pair.txt
- train_pair.txt # 用于读入数据
generate_train_file.py 代码
generate_train_file.py 用于生成指定数据集的 train_pair.txt. 因为train.py通过调用 data.py 里面的getTrainGenerator函数实现对数据的遍历,所以针对不同的数据集,需要生成不同的train_pair.txt
import os
dataset_root = "cvpr2019_PFA/ECSSD"#此处使用的是ECSSD数据集,也可以换成其他的
img_list = []
def check_num_images():
jpg_count = 0
gt_count = 0
for root, dirs, files in os.walk(dataset_root):
for fname in files:
if 'jpg' in fname:
jpg_count+=1
img_list.append(fname[:-4])
if 'png' in fname:
gt_count+=1
print ("num of images: {}, num of GT maps: {}".format(jpg_count, gt_count))
check_num_images()
with open("train_pair.txt", 'w+') as fout:
for img in img_list:
img_path = os.path.join(dataset_root, img)
fout.write(f'{
img_path}.jpg {
img_path}.png\n')