超详细!DALL · E 文生图模型实践指南

最近需要用到 DALL·E的推断功能,在现有开源代码基础上发现还有几个问题需要注意,谨以此篇博客记录之。

我用的源码主要是 https://github.com/borisdayma/dalle-mini 仓库中的Inference pipeline.ipynb 文件。

在这里插入图片描述

运行环境:Ubuntu服务器

⚠️注意:本博客仅涉及 DALL · E 推断,不涉及训练过程。



一、环境配置

建议使用anaconda新建一个dalle环境,然后在该环境中进行相关配置,避免与环境中的其他库产生版本冲突。

使用下述命令新建名为dalle的环境:

conda create -n dalle python==3.8.0

在终端分别运行下述命令,安装所需的python库:

# 安装 dalle运行需要的依赖库(注意版本只能是0.3.25)# Required only for colab environments + GPU
pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# 安装 dalle特定的库
pip install dalle-mini
# 安装 VQGAN
pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

PS:如果由于网络连接问题无法通过pip命令下载VQGAN,就采取Plan-B:将仓库 https://github.com/patil-suraj/vqgan-jax 下载到服务器并解压,然后使用cd命令将当前目录到对应的仓库下载路径下,在终端运行python setup.py install安装VQGAN即可。


二、模型下载

由于网络连接问题,我采取「事先把模型下载到本地」的策略对模型进行直接调用,首先要明确的一点是,本项目中使用DALL · E 对图像进行编码,使用VQGAN对图像进行解码,所以我们需要分别下载DALL · E 和 VQGAN 两个模型。

DALL · E 模型下载地址:
mini版本:https://huggingface.co/dalle-mini/dalle-mini/tree/main
mega版本:https://huggingface.co/dalle-mini/dalle-mega/tree/main

VQGAN 模型下载地址:
https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/main

下载完毕后,将模型部署到服务器,注意保存路径。


三、程序转换

相较于ipynb文件,我个人更加喜欢操作py文件,所以对于给定的ipynb文件,首先使用命令jupyter nbconvert --to script Inference pipeline.ipynb 将其转为同名py文件,该文件的主要内容如下(不含CLIP排序部分),其中模型路径 DALLE_MODEL和VQGAN_REPO 已改为本地路径(就是第二步中两个模型的保存路径),可以看到文件的注释也比较详细。

# dalle-mini
DALLE_MODEL = "/newdata/SD/dalle-mini/dalle-mini"
DALLE_COMMIT_ID = None
# VQGAN model
VQGAN_REPO = "/newdata/SD/dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

import jax
import jax.numpy as jnp

# check how many devices are available
jax.local_device_count()

# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
# Load dalle-mini
model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)
# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False)

# Model parameters are replicated on each device for faster inference.
from flax.jax_utils import replicate
params = replicate(params)
vqgan_params = replicate(vqgan_params)

# Model functions are compiled and parallelized to take advantage of multiple devices.
from functools import partial

# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )

# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)

# Keys are passed to the model on each device to generate unique inference per device.
import random

# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)

# ## 🖍 Text Prompt
# Our model requires processing prompts.

from dalle_mini import DalleBartProcessor 
# from transformers import AutoProcessor
processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID)  # force_download=True, , local_only=True
# Let's define some text prompts
prompts = [
    "sunset over a lake in the mountains",
    "the Eiffel tower landing on the moon",
]
# print(prompts)
# Note: we could use the same prompt multiple times for faster inference.
tokenized_prompts = processor(prompts)
# Finally we replicate the prompts onto each device.
tokenized_prompt = replicate(tokenized_prompts)

# ## 🎨 We generate images using dalle-mini model and decode them with the VQGAN.

# number of predictions per prompt
n_predictions = 8

# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0  # 越高,生成的图像越接近 prompt

from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

print(f"Prompts: {prompts}\n")
# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)  #  jax.device_count()=1,returns the number of available jax devices
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    
    for idx, decoded_img in enumerate(decoded_images):
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
... 

四、程序运行

使用命令 python /newdata/SD/inference_dalle-mini.py 运行程序。理想情况下就能够直接得到dalle生成的图像啦!


五、BUG清除指南

由于外部环境因素和一些不当操作,本人在运行该程序过程中还是遇到一些问题,主要有三个,在此将抱错信息与解决方法一并分享给大家。

  • 因网络问题导致特定文件下载失败,报错信息如下:
...
requests.exceptions.ConnectTimeout: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /dalle-mini/dalle-mini/resolve/main/enwiki-words-frequency.txt (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7faae4168460>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: 61b7c191-3fb8-4dfa-9025-e9acd4ee4d28)')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/newdata/SD/inference_dalle-mini.py", line 84, in <module>
    processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID)  # force_download=True, , local_only=True
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/utils.py", line 25, in from_pretrained
    return super(PretrainedFromWandbMixin, cls).from_pretrained(
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 62, in from_pretrained
    return cls(tokenizer, config.normalize_text, config.max_text_length)
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 21, in __init__
    self.text_processor = TextNormalizer()
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 215, in __init__
    self._hashtag_processor = HashtagProcessor()
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 25, in __init__
    #     wiki_word_frequency = hf_hub_download(
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
    return fn(*args, **kwargs)
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 1363, in hf_hub_download
    raise LocalEntryNotFoundError(
huggingface_hub.utils._errors.LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.

顺着上面的报错信息,定位到/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py文件的如下内容:

...
class HashtagProcessor:
    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
		wiki_word_frequency = hf_hub_download(
		    "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
		)
		self._word_cost = (
		    l.split()[0]
		    for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
		)
...

于是问题的根源就在于,程序运行到这里时,没有找到本地的enwiki-words-frequency.txt文件(经检查该文件其实是存在本地的,不知为何没有找到,很迷),于是尝试通过联网从huggingface官网下载,但由于网络状况欠佳,联网失败,于是报错。解决办法如下:

...
class HashtagProcessor:
    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
		wiki_word_frequency = "/newdata/SD/dalle-mini/dalle-mini/enwiki-words-frequency.txt"
		self._word_cost = (
		    l.split()[0]
		    for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
		)
...

也就是将enwiki-words-frequency.txt文件的本地路径直接赋值给wiki_word_frequency变量,其余部份保持不变,问题解决。


  • 因安装不当导致的版本冲突问题
FIx for "Couldn't invoke ptxas --version"

这个错误的产生是不同python库安装时带来的版本冲突导致的,DALLE-mini要求jax和jaxlib版本必须为0.3.25,但是通过pip imstall dalle-mini 命令安装后的jaxlib版本为0.4.13,但使用pip install jaxlib的方式并不能找到0.3.25版本的jaxlib,而且会产生与flax、orbax-checkpoint等其他库的版本不兼容问题……在尝试多种方法合理降低jaxlib版本均失败后,发现答案就在ipynb中……也就是:pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

💡启示:要以官方说明文档为主,可以少走很多弯路!!!


  • 彩蛋:一个非常奇怪的错误:
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/newdata/SD/inference_dalle-mini.py", line 130, in <module>
    decoded_images = p_decode(encoded_images, vqgan_params)
ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * most axes (101 of them) had size 512, e.g. axis 0 of argument params['decoder']['conv_in']['bias'] of type float32[512];
  * some axes (71 of them) had size 3, e.g. axis 0 of argument params['decoder']['conv_in']['kernel'] of type float32[3,3,256,512];
  * some axes (69 of them) had size 256, e.g. axis 0 of argument params['decoder']['up_1']['block_0']['norm1']['bias'] of type float32[256];
  * some axes (67 of them) had size 128, e.g. axis 0 of argument params['decoder']['norm_out']['bias'] of type float32[128];
  * some axes (35 of them) had size 1, e.g. axis 0 of argument indices of type int32[1,2,256];
  * one axis had size 16384: axis 0 of argument params['quantize']['embedding']['embedding'] of type float32[16384,256]

后来发现,是因为之前调试的时候不小心把下面这行代码注释掉了……这个bug排得最辛苦,还挺无语的😂

vqgan_params = replicate(vqgan_params)
  • 因版本限制导致显卡无法正常使用
    由于dalle-mini限制jax和jaxlib的版本只能是0.3.25,因此,无法更新这两个包到最新版本,不知是不是因为这个原因会出现如下报错信息:
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: Couldn't get ptxas version string: INTERNAL: Couldn't invoke ptxas --version

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
  • 程序运行过程中还有一些警告,由下述警告也可以看出jax是属于tensorflow派别的。(我这个程序没有识别出显卡的存在,导致只能在CPU上运行)
2023-11-07 11:30:35.139851: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.257514: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.258648: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.628768: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
2023-11-07 11:30:35.628915: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 525.53.0 does not match DSO version 530.41.3 -- cannot find working devices in this configuration
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Prompts: ['sunset over a lake in the mountains', 'the Eiffel tower landing on the moon']

  0%|          | 0/8 [00:00<?, ?it/s]
/root/anaconda3/envs/dalle/lib/python3.8/site-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float16 to dtype=float32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "

后记:第一次接触到基于jax框架编写的程序,还挺新鲜的,感觉和pytorch有一些不一样的地方。了解到jax是tensorflow的轻量级版本。上述博客内容中如果有个人理解不当之处,还望各位批评指正!

参考链接

  1. python pathlib中Path 的使用(解决不同操作系统的路径问题)_python pathlib.path-CSDN博客
  2. python - vmap gives inconsistent shape error when trying to calculate gradient per sample - Stack Overflow
  3. https://github.com/google/jax/issues/9933
  • 4
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

_Meilinger_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值