AttGAN实验复现 2024

AttnGAN 代码复现 2024

简介

论文地址: https://arxiv.org/pdf/1711.10485.pdf

代码 python2.7(论文源码):https://github.com/taoxugit/AttnGAN

python2.7基本废掉了,不建议

代码 python3:https://github.com/davidstap/AttnGAN

在这里插入图片描述

环境

名称版本备注
NVIDIA-SMI 531.79CUDA Version: 12.1win11
torch2.3.1+cu121最新
torchvision0.18.1+cu121
Python3.10.14
笔记本 40608G 显存

注意 python2.7版本已经无法下载 torch

参考博客:Torch 、torchvision 、Python 版本对应关系

python 依赖

python-dateutil
easydict
pandas
torchfile
nltk
scikit-image
pyyaml

数据集

  1. 鸟类预处理的元数据:https://drive.google.com/file/d/1O_LtUP9sch09QH3s_EBAgLEctBQ5JBSJ/view,保存到 data
  2. 下载鸟类图像数据:http://www.vision.caltech.edu/datasets/cub_200_2011/ 将它们提取到 data/birds/

在这里插入图片描述

在这里插入图片描述

Training

Pre-train DAMSM

运行命令

cd code
python pretrain_DAMSM.py --cfg cfg/DAMSM/bird.yml --gpu 0

配置文件

code/cfg/DAMSM/bird.yml

TRAIN:
    FLAG: True
    NET_E: ''
    BATCH_SIZE: 48  # RuntimeError: CUDA out of memory,可减少batch_size
    MAX_EPOCH: 600  # 训练轮数目
    SNAPSHOT_INTERVAL: 50  # 每训练50轮保存模型

yml 文件中不要写中文,中文会出现读文件错误!

问题1

re_img = transforms.Resize(imsize[i])(img)

IndexError: list index out of range

code/datasets.py

# 旧代码
if i < (cfg.TREE.BRANCH_NUM - 1)

# 新代码 (如果你修改成 TREE.BRANCH_NUM -2 后续会导致很多问题)
if i < len(imsize)

在这里插入图片描述

问题2

IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

code/pretrain_DAMSM.py:提示我们将 loss[0] 改为 loss.item()

在这里插入图片描述

问题3

TypeError: load() missing 1 required positional argument: ‘Loader’

code/miscc/config.py

# 旧代码
yaml_cfg = edict(yaml.load(f))

# 新代码
yaml_cfg = edict(yaml.safe_load(f))

问题4

RuntimeError: masked_fill only supports boolean masks, but got dtype Byte

code/miscc/losses.py

# 旧代码(windows下不适用,但liunx下适用)
data.masked_fill_(masks, -float('inf'))

# 新代码(相反)
data.masked_fill_(masks.bool(), -float('inf'))

问题5

TypeError: pyramid_expand() got an unexpected keyword argument ‘multichannel’

code/miscc/utils.py

# 旧代码
skimage.transform.pyramid_expand(one_map, sigma=20,
                                                     upscale=vis_size // att_sze,
                                                     multichannel=True)

# 新代码
skimage.transform.pyramid_expand(one_map, sigma=20,
                                                     upscale=vis_size // att_sze,
                                                     channel_axis=-1)

问题6

OSError: cannot open resource

code/miscc/utils.py

# 旧代码
fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)

# 新代码
script_dir = os.path.dirname(os.path.abspath(__file__))
font_path = os.path.join(script_dir, 'FreeMono.ttf')
fnt = ImageFont.truetype(font_path, 50)

eval/FreeMono.ttf 复制到 code/miscc/FreeMono.ttf

在这里插入图片描述

Train AttnGAN

运行命令

python main.py --cfg cfg/bird_attn2.yml --gpu 0

预训练模型下载

如果没有进行 Pre-train DAMSM 步骤,可直接下载 DAMSM 预训练模型 保存到 DAMSMencoders

在这里插入图片描述

配置文件

code/cfg/bird_attn2.yml

TRAIN:
    FLAG: True
	...
    NET_E: '../DAMSMencoders/bird/text_encoder200.pth'  # DAMSM 预训练模型存放位置,可自定义

注意:配置文件中不要进行中文注释

问题1

FileNotFoundError: [Errno 2] No such file or directory: '../DAMSMencoders/bird/image_encoder200.pth'

找不到.pth文件,比对路径发现是 …/DAMSMencoders/birds/image_encoder200.pth

  • 修改后

在这里插入图片描述

问题2

AttributeError: ‘_MultiProcessingDataLoaderIter’ object has no attribute ‘next’

code/trainer.py

# 旧代码
data = data_iter.next()

# 新代码
data = next(data_iter)

问题3

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU

code/cfg/bird_attn2.yml

# 旧代码
BATCH_SIZE: 20  # 22

# 新代码
BATCH_SIZE: 10  # 20 22

警告

UserWarning: This overload of add_ is deprecated

code/trainer.py

# 修改前
avg_p.mul_(0.999).add_(0.001, p.data)

# 修改后
avg_p.mul_(0.999).add_(p.data, alpha=0.001)

注意:警告可以不进行修改

如果你不想进行Trian AttnGAN 步骤,可以下载已训练好的 AttnGAN 模型 保存到 models/

在这里插入图片描述

Sampling

运行命令

python main.py --cfg cfg/eval_bird.yml --gpu 0

B_VALIDATION 为 False (默认)

./data/birds/example_filenames.txt 中的文本(可自定义句子)作为输入生成样本图像,没啥用还会干扰评价指标测量

配置文件

code/cfg/eval_bird.yml

B_VALIDATION: False

BATCH_SIZE: 100

错误1

Length of all samples has to be greater than 0, but found an element in ‘lengths’ that is <= 0

  • 进行控制变量法
    • data/birds/example_filenames.txt 进行不断删减,检测存在问题的输入样本

正确运行

example_captions
text/180.Wilson_Warbler/Wilson_Warbler_0007_175618
text/180.Wilson_Warbler/Wilson_Warbler_0024_175278
text/180.Wilson_Warbler/Wilson_Warbler_0074_175645
text/180.Wilson_Warbler/Wilson_Warbler_0107_175320
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0001_163813
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0008_164001
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0016_164060
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0035_163587
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0101_164324
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0103_163669

在这里插入图片描述

运行报错

example_captions
text/138.Tree_Swallow/Tree_Swallow_0002_136792
text/138.Tree_Swallow/Tree_Swallow_0008_135352
text/138.Tree_Swallow/Tree_Swallow_0030_134942
text/138.Tree_Swallow/Tree_Swallow_0050_135104
text/138.Tree_Swallow/Tree_Swallow_0117_134925
text/098.Scott_Oriole/Scott_Oriole_0002_795829
text/098.Scott_Oriole/Scott_Oriole_0014_795827
text/098.Scott_Oriole/Scott_Oriole_0018_795840
text/098.Scott_Oriole/Scott_Oriole_0046_92371

在这里插入图片描述

反手查看源文件

data/birds/text/138.Tree_Swallow/Tree_Swallow_0030_134942.txt

the blue backed, white bellied baby bird has a very fat little belly
this is a bird with a white belly and breast and a blue back and head.
the bird has a small white body with blue and green colored crown and coverts.
a small round bird with a white and blue body.
this��bird��has��a��white��belly,��dark��blue��wings,��and��a��small,��short��bill.
this bird is blue with white and has a very short beak.
this is a blue bird with a white throat, breast, belly and abdomen and a small black pointed beak
this small bird has a white beast and blue crest and back.
this bird has wings that are blue and has a white belly
this bird has wings that are blue and has a white belly

偷真的🐕,将��替换成空格后正常运行!!!

B_VALIDATION 为 True

data/birds/text 中文本作为输入,生成样本图像,用于评价指标测量

code/cfg/eval_bird.yml

# 新代码
B_VALIDATION: True

# 新代码(8G显存)
BATCH_SIZE: 20

运行命令

python main.py --cfg cfg/eval_bird.yml --gpu 0

结果

Make a new folder:  ../models/bird_AttnGAN2/valid/single/156.White_eyed_Vireo
step:  100
Total time for training: 91.08317875862122

在这里插入图片描述

参考博客

AttnGAN代码复现(详细步骤+避坑指南)文本生成图像

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值