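A detailed look at the call parameters of the pipeline in the Diffusers library
Preface
When deploying Diffusers in the cloud, or simply calling it locally to generate images, I found that neither the Hugging Face Diffusers documentation, the GitHub repository, nor the many tutorials online clearly list the parameters that can be passed to the pipeline. While stepping through the code line by line, I finally found the explanation of the accepted parameters in the diffusers source code. It is reproduced below: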
- prompt (`str` or `List[str]`, optional): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
- height (`int`, optional, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image.
- width (`int`, optional, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image.
- num_inference_steps (`int`, optional, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
- guidance_scale (`float`, optional, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
- negative_prompt (`str` or `List[str]`, optional): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
- num_images_per_prompt (`int`, optional, defaults to 1): The number of images to generate per prompt.
- eta (`float`, optional, defaults to 0.0): Corresponds to parameter eta (η) from the DDIM paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
- generator (`torch.Generator` or `List[torch.Generator]`, optional): A `torch.Generator` to make generation deterministic.
- latents (`torch.FloatTensor`, optional): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, optional): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, optional): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
- output_type (`str`, optional, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`.
- return_dict (`bool`, optional, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple.
- callback (`Callable`, optional): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
- callback_steps (`int`, optional, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step.
- cross_attention_kwargs (`dict`, optional): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in `self.processor`.
- guidance_rescale (`float`, optional, defaults to 0.7): Guidance rescale factor from "Common Diffusion Noise Schedules and Sample Steps are Flawed". Guidance rescale factor should fix overexposure when using zero terminal SNR.
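The default height and width, and the shape a pre-generated `latents` tensor would be expected to have, both follow from the `sample_size * vae_scale_factor` formula in the docstring. Below is a minimal sketch of that arithmetic. It assumes the values usually reported for Stable Diffusion v1.5 (a UNet `sample_size` of 64, a `vae_scale_factor` of 8, and 4 latent channels); the helper name `latent_shape` is made up for illustration and is not part of diffusers.

```python
# Assumed values for Stable Diffusion v1.5; check pipe.unet.config and
# pipe.vae_scale_factor on your own pipeline before relying on them.
SAMPLE_SIZE = 64
VAE_SCALE_FACTOR = 8
LATENT_CHANNELS = 4


def latent_shape(batch_size=1, num_images_per_prompt=1, height=None, width=None):
    """Return the (N, C, H, W) shape of the latents for one pipeline call."""
    # Fall back to the documented default: sample_size * vae_scale_factor.
    height = height or SAMPLE_SIZE * VAE_SCALE_FACTOR
    width = width or SAMPLE_SIZE * VAE_SCALE_FACTOR
    if height % VAE_SCALE_FACTOR or width % VAE_SCALE_FACTOR:
        raise ValueError(f"height and width must be divisible by {VAE_SCALE_FACTOR}")
    return (
        batch_size * num_images_per_prompt,
        LATENT_CHANNELS,
        height // VAE_SCALE_FACTOR,
        width // VAE_SCALE_FACTOR,
    )


print(latent_shape())  # (1, 4, 64, 64), i.e. a 512x512 image
print(latent_shape(num_images_per_prompt=2, height=768, width=768))  # (2, 4, 96, 96)
```

Under these assumptions the default image is 512x512, and sizes that are not multiples of 8 are rejected before any denoising starts.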
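Of these, the ones used most often are:
- prompt: the positive prompt
- height, width: the height and width of the generated image
- num_inference_steps: rarely covered in tutorials; it sets the number of denoising steps run during generation, trading image quality against speed
- guidance_scale: how closely the image follows the text. The higher the value, the closer the image stays to the prompt, but bigger is not always better: very large values noticeably degrade image quality
- negative_prompt: the negative prompt
- num_images_per_prompt: the number of images generated per prompt in one call
- generator: the random generator (used, for example, to fix the seed)
The remaining parameters are used much less often.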
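Why a higher guidance_scale pulls the image toward the prompt can be seen from the classifier-free guidance formula. The following is a minimal pure-Python sketch of that formula on plain lists of floats, not the pipeline's actual tensor code; the function name `apply_guidance` and the sample numbers are made up for illustration.

```python
def apply_guidance(noise_uncond, noise_text, guidance_scale):
    """Combine unconditional and text-conditioned noise predictions.

    result = uncond + guidance_scale * (text - uncond)
    """
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


uncond = [0.0, 1.0]  # prediction for the empty (or negative) prompt
text = [1.0, 3.0]    # prediction for the actual prompt

print(apply_guidance(uncond, text, 1.0))  # [1.0, 3.0], same as the text prediction
print(apply_guidance(uncond, text, 7.5))  # [7.5, 16.0], pushed further along (text - uncond)
```

At a scale of 1 the unconditional term cancels out, which is consistent with the docstring's note that guidance is enabled only when `guidance_scale > 1`. Larger scales exaggerate the difference between the two predictions, which is one way to see why very high values tend to hurt image quality.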
The pipeline file is diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py; the parameter documentation is at around line 546 (the exact line varies between versions).
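Because line numbers shift between releases, the parameter list can also be read programmatically instead of being looked up in the file. Here is a minimal sketch using the standard-library `inspect` module. `list_call_parameters` is a hypothetical helper, and `fake_call` is a stand-in so the snippet runs without diffusers installed; with diffusers installed, passing `StableDiffusionPipeline.__call__` instead should list the real parameters.

```python
import inspect


def list_call_parameters(func):
    """Return {name: default} for every parameter of func except self."""
    params = {}
    for name, param in inspect.signature(func).parameters.items():
        if name == "self":
            continue
        # Parameters without a default are reported as "<required>".
        if param.default is inspect.Parameter.empty:
            params[name] = "<required>"
        else:
            params[name] = param.default
    return params


# Stand-in with a few of the parameters discussed above (not the real method).
def fake_call(self, prompt=None, height=None, width=None,
              num_inference_steps=50, guidance_scale=7.5):
    ...


for name, default in list_call_parameters(fake_call).items():
    print(f"{name} = {default!r}")
```

Calling `help()` on the same method prints the docstring quoted earlier, which is usually quicker than searching the source.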
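Example:
from diffusers import StableDiffusionPipeline
import torch  # not used directly here; needed once you pass a torch.Generator or torch_dtype

model_id = "runwayml/stable-diffusion-v1-5"
# Download (or load from cache) the pretrained weights
pipe = StableDiffusionPipeline.from_pretrained(model_id)
# Move the pipeline to the GPU
pipe = pipe.to("cuda")
# height and width must be divisible by 8; .images is a list of PIL images
image = pipe("An image of a squirrel in Picasso style", height=768, width=768, guidance_scale=7).images[0]
image.save("squirrel.png")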
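The other frequently used parameters (negative_prompt, num_inference_steps, num_images_per_prompt, generator) can be passed in the same way. The sketch below rests on a few assumptions: a CUDA GPU is available, the helper names `build_call_kwargs` and `generate` are made up for illustration, and the negative prompt and seed are arbitrary. torch and diffusers are imported inside `generate` so that the argument-building part can be tried without them.

```python
def build_call_kwargs(prompt, negative_prompt=None, height=512, width=512,
                      num_inference_steps=50, guidance_scale=7.5,
                      num_images_per_prompt=1):
    """Collect the commonly used pipeline arguments into one dict."""
    if height % 8 or width % 8:
        raise ValueError("height and width must be divisible by 8")
    return {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "height": height,
        "width": width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "num_images_per_prompt": num_images_per_prompt,
    }


def generate(model_id, seed, **call_kwargs):
    """Load the pipeline and return a list of PIL images (needs a CUDA GPU)."""
    # Imported lazily so build_call_kwargs works without the GPU stack.
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(model_id)
    pipe = pipe.to("cuda")
    # A seeded generator makes the run reproducible.
    generator = torch.Generator(device="cuda").manual_seed(seed)
    return pipe(generator=generator, **call_kwargs).images


kwargs = build_call_kwargs(
    "An image of a squirrel in Picasso style",
    negative_prompt="blurry, low quality",
    num_inference_steps=30,
    num_images_per_prompt=2,
)
print(kwargs["num_images_per_prompt"])  # 2

# Uncomment on a machine with a GPU and access to the model weights:
# images = generate("runwayml/stable-diffusion-v1-5", seed=42, **kwargs)
# for i, image in enumerate(images):
#     image.save(f"squirrel_{i}.png")
```

With the same seed, prompt, and settings, repeated runs on the same hardware and library versions should give the same images, which makes it easier to compare the effect of changing a single parameter.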