目录
6.指标计算eval.py ssim psnr mae mse等
1.数据集
1.1文件夹组织形式
1.2 修改代码路径
参数修改指向数据集路径,可根据需要修改迭代次数(我设置了10次迭代,修复效果很一般,我的数据集大小为1.6w张,可以视情况修改)、保存间隔等
2.环境requirements
2.1:官方使用
2.2:我用的环境配置:
由于torch和torchvision版本过老我使用的是python=3.9 cuda=117 torch=1.3.1
注意:numpy的版本为1开头的!!!不然会报错
我将numpy降级
conda install numpy==1.24.3
以及重新安装依赖
conda install pillow matplotlib
3.成功运行
4.注意事项及路径等参数的修改
4.1.数据集格式修改:
在places2.py文件中,可以将.jpg格式的文件转为png(需要的话)
修改了代码可以移除损坏文件和检测读取数量
import random
import torch
from PIL import Image, ImageFile
from glob import glob
import os
# 允许加载不完整图片
ImageFile.LOAD_TRUNCATED_IMAGES = True
class Places2(torch.utils.data.Dataset):
def __init__(self, img_root, mask_root, img_transform, mask_transform, split='train'):
super(Places2, self).__init__()
self.img_transform = img_transform
self.mask_transform = mask_transform
self.split = split
# 确保路径存在
if not os.path.exists(img_root):
raise RuntimeError(f'Image root path does not exist: {img_root}')
if not os.path.exists(mask_root):
raise RuntimeError(f'Mask root path does not exist: {mask_root}')
# 加载数据路径
if split == 'train':
self.paths = glob('{:s}/data_large/**/*.png'.format(img_root), recursive=True)