深度学习系列31:Dalle生成模型

1. Dalle模型

前面介绍过VAVQE模型,它本质上是一个encoder-decoder模型,只是中间加了一个codebook。这样我们就可以把尺寸大大缩小。
得到codebook后,图片可以用其进行编码,然后使用自回归模型(比如transformer)来进行序列生成。Taming Transformer就是这样的一个模型。与之相对应的,是早起的PixelCNN、PixelRNN等直接在像素级别进行序列预测的模型,只能处理32*32这样的尺寸。
Dalle模型和Taming Transformer基本相同,只是把输入把文字tokens拼接到了图片tokens前面。
在这里插入图片描述

2. 模型训练代码

先安装:pip install dalle-pytorch
伪代码如下:
1)训练VAE的codebook

import torch
from dalle_pytorch import DiscreteVAE
vae = DiscreteVAE()
loss = vae(images, return_loss = True)
loss.backward()

这步可以跳过,直接使用OpenAI现成的VAE模型:

from dalle_pytorch import OpenAIDiscreteVAE
vae = OpenAIDiscreteVAE() 

或者用Taming Transformer中预训练的VQGAN VAE:

from dalle_pytorch import VQGanVAE
vae = VQGanVAE()

2)训练dalle模型

import torch
from dalle_pytorch import DALLE
dalle = DALLE(vae = vae)
loss = dalle(text, images, return_loss = True)
loss.backward()

3)生成图片

images = dalle.generate_images(text)
# or images = dalle.generate_images(
    text, img = img_prime,num_init_img_tokens = (14 * 32) )

3. 预测部分代码

网上有训练好的模型:https://github.com/robvanvolt/DALLE-models
然后执行:

python generate.py --dalle_path=模型路径 --taming --text=文本内容 --num_images=1 --batch_size=1 --outputs_dir=输出地址

参考这篇https://github.com/rom1504/dalle-service可以部署网页服务,或者在jupyter中执行:
在这里插入图片描述

  • 2
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值