StableSR扩散超分辨率模型训练总结

StableSR模型出自论文《Exploiting Diffusion Prior for Real-World Image Super-Resolution》,使用扩散模型做自然界真实影像的超分辨率。其数据增强部分参考Real-ESRGAN工程,因此该模型也可以算是盲超分领域。文章的具体原理可以看论文详细了解,本篇主要介绍模型的训练过程。

一、训练环境配置

我是用anaconda新创建了一个虚拟环境,然后根据作者的环境要求进行了配置,过程挺顺利的,没遇到啥问题,具体如下:

  • Pytorch == 1.12.1
  • CUDA == 11.7
  • pytorch-lightning==1.4.2
  • xformers == 0.0.16 (Optional)
  • Other required packages in environment.yaml
# git clone this repository
git clone https://github.com/IceClear/StableSR.git
cd StableSR

# Create a conda environment and activate it
conda env create --file environment.yaml
conda activate stablesr

# Install xformers
conda install xformers -c xformers/label/dev

# Install taming & clip
pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
pip install -e .

二、预训练模型下载

需要从HuggingFace上下载预训练的Stable Diffusion模型,下载路径如下:

https://huggingface.co/stabilityai/stable-diffusion-2-1-base

点击画圈部分下载,模型比较大,有5个多G

三、训练集准备

我做的是遥感影像的超分,因此就是将高分辨率的遥感影像切分成512x512的patch,放到一个文件夹里。修改配置文件v2-finetune_text_T_512.yaml的gt_path,设置为文件夹路径。

四、训练配置文件修改

修改配置文件v2-finetune_text_T_512.yaml,主要配置ckpt_path的路径,修改为下载的Stable Diffusion预训练模型路径。

其他参数基本不用修改,我是batch_size默认用的6, 3090的卡显存不够,可以修改batch_size和queue_size,适当调整改小一些。

五、阶段一模型训练

主要训练的事Time-aware encoder模型,训练脚本如下:

python main.py --train --base configs/stableSRNew/v2-finetune_text_T_512.yaml --gpus GPU_ID, --name NAME --scale_lr False

如果有多块GPU,可以设置GPU_ID(0,1),--name参数为训练时产生文件的存储文件夹名字,自己设置,可以加一个参数--no-test,这样在训练时不会进行验证,因为我没有准备验证集

整个训练过程是比较耗时的,配置文件设置的迭代次数是80w次,训练过程中会在logs文件夹下存储可视化结果,自己可根据samples生成的结果与gt对比,判断模型的训练程度。作者给的工程也配置了wandb,可以在线看训练过程统计,包含loss之类的,我因为是后台训练,就直接把这个关掉了,在main.py文件中设置offline参数为True。

六、阶段二模型训练

阶段二主要是VQGAN模型的训练,用使用的配置文件是autoencoder_kl_64x64x4_resi.yaml。用第一阶段训练模型的gts数据,用Real-ESRGAN的数据增强方式生成inputs数据,即降质的数据。需要将原来工程中数据增强的代码摘出来单独写一个脚本。如有需要,可以找我提供。

然后需要使用第一阶段训练得到的模型last.ckpt,生成latent。需要用到Stable-SR工程中scripts文件夹中的sr_val_ddpm_text_T_vqganfin_old.py脚本,修改脚本中部分内容,

作者用的还是最初的预训练模型,我用的第一阶段训练的last.ckpt

之前脚本时一次性读入测试图片缓存,训练图像数据量大内存会爆,所以最好改成单张图像读取测试。修改部分如下所示:

配置文件路径

注释掉原来的一次性图像读取模式

改成一次读取一张图像测试

latnet保存位置

保存成npy文件

运行脚本

python scripts/sr_val_ddpm_text_T_vqganfin_old.py --config configs/stableSRNew/v2-finetune_text_T_512.yaml --ckpt pretrain/last.ckpt --vqgan_ckpt pretrain/last.ckpt --init-img CFW_trainingdata/inputs/ --outdir CFW_trainingdata/samples --ddpm_steps 200 --dec_w 0.0 --colorfix_type adain

注意设置参数--dec_w 0.0,因为这时候VQGAN模型还没训练好。生成好之后依次将gts,inputs,samples,latents放到一个文件夹里。

第二阶段模型训练脚本

python main.py --train --base configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml --gpus GPU_ID, --name NAME --scale_lr False --no-test

迭代合适的次数,看可视化结果判断,得到训练好的VQGAN模型,结合第一阶段模型last.ckpt,使用sr_val_ddpm_text_T_vqganfin_old.py脚本进行推理,在视觉效果上看来,比Real-ESRGAN模型的细节还原确实更好。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值