这段代码定义了一个名为 Zero123PlusPipeline
的类,它继承自 diffusers.StableDiffusionPipeline
。这个类实现了一个基于扩散模型的图像生成管道,结合了多个组件如VAE、文本编码器、图像编码器、UNet、调度器等。下面是对每个部分的详细讲解:
类的定义和属性
class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
tokenizer: transformers.CLIPTokenizer
text_encoder: transformers.CLIPTextModel
vision_encoder: transformers.CLIPVisionModelWithProjection
feature_extractor_clip: transformers.CLIPImageProcessor
unet: UNet2DConditionModel
scheduler: diffusers.schedulers.KarrasDiffusionSchedulers
vae: AutoencoderKL
ramping: nn.Linear
feature_extractor_vae: transformers.CLIPImageProcessor
depth_transforms_multi = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
这些是类中的属性声明,包括:
tokenizer
,text_encoder
,vision_encoder
: 分别为文本编码器和视觉编码器的组件。feature_extractor_clip
,feature_extractor_vae
: 特征提取器,用于处理输入图像。unet
: UNet模型,用于图像生成。scheduler
: 调度器,用于控制扩散过程。vae
: 变分自编码器(VAE),用于图像编码和解码。ramping
: 线性层,用于调整编码特征。depth_transforms_multi
: 用于深度图像的预处理转换。
初始化方法
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
vision_encoder: transformers.CLIPVisionModelWithProjection,
feature_extractor_clip: CLIPImageProcessor,
feature_extractor_vae: CLIPImageProcessor,
ramping_coefficients: Optional[list] = None,
safety_checker=None,
):
DiffusionPipeline.__init__(self)
self.register_modules(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
unet=unet, scheduler=scheduler, safety_checker=None,
vision_encoder=vision_encoder,
feature_extractor_clip=feature_extractor_clip,
feature_extractor_vae=feature_extractor_vae
)
self.register_to_config(ramping_coefficients=ramping_coefficients)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
- 初始化方法接收多个组件作为参数,并调用父类的初始化方法。
- 使用
register_modules
方法注册这些组件。 - 计算
vae_scale_factor
,用于图像处理。 - 初始化
image_processor
,用于后处理图像。
prepare
方法
def prepare(self):
train_sched = DDPMScheduler.from_config(self.scheduler.config)
if isinstance(self.unet, UNet2DConditionModel):
self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()
- 该方法初始化调度器和UNet模型。
- 如果UNet是条件模型,则将其转换为只参考噪声的UNet,并设置为评估模式。
add_controlnet
方法
def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):
self.prepare()
self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))
- 该方法添加控制网络,并将其集成到UNet模型中。
- 返回一个包含控制网络的超级网络。
encode_condition_image
方法
def encode_condition_image(self, image: torch.Tensor):
image = self.vae.encode(image).latent_dist.sample()
return image
- 该方法对输入图像进行编码,生成潜在表示。
__call__
方法
@torch.no_grad()
def __call__(
self,
image: Image.Image = None,
prompt = "",
*args,
num_images_per_prompt: Optional[int] = 1,
guidance_scale=4.0,
depth_image: Image.Image = None,
output_type: Optional[str] = "pil",
width=640,
height=960,
num_inference_steps=28,
return_dict=True,
**kwargs
):
self.prepare()
if image is None:
raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
assert not isinstance(image, torch.Tensor)
image = to_rgb_image(image)
image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
if depth_image is not None and hasattr(self.unet, "controlnet"):
depth_image = to_rgb_image(depth_image)
depth_image = self.depth_transforms_multi(depth_image).to(
device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
)
image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
cond_lat = self.encode_condition_image(image)
if guidance_scale > 1:
negative_lat = self.encode_condition_image(torch.zeros_like(image))
cond_lat = torch.cat([negative_lat, cond_lat])
encoded = self.vision_encoder(image_2, output_hidden_states=False)
global_embeds = encoded.image_embeds
global_embeds = global_embeds.unsqueeze(-2)
if hasattr(self, "encode_prompt"):
encoder_hidden_states = self.encode_prompt(
prompt,
self.device,
num_images_per_prompt,
False
)[0]
else:
encoder_hidden_states = self._encode_prompt(
prompt,
self.device,
num_images_per_prompt,
False
)
ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
cak = dict(cond_lat=cond_lat)
if hasattr(self.unet, "controlnet"):
cak['control_depth'] = depth_image
latents: torch.Tensor = super().__call__(
None,
*args,
cross_attention_kwargs=cak,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=encoder_hidden_states,
num_inference_steps=num_inference_steps,
output_type='latent',
width=width,
height=height,
**kwargs
).images
latents = unscale_latents(latents)
if not output_type == "latent":
image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
else:
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
这是主要的图像生成方法,包含以下步骤:
- 准备模型。
- 检查输入图像是否为空,并将其转换为RGB格式。
- 使用特征提取器处理图像,生成两种格式的图像特征。
- 如果有深度图像,进行处理。
- 将图像转换为设备和数据类型相匹配的格式。
- 对图像进行编码,生成潜在表示。
- 根据
guidance_scale
决定是否进行引导。 - 使用视觉编码器对图像特征进行编码。
- 处理文本提示,生成编码隐藏状态。
- 调整编码隐藏状态,结合全局嵌入。
- 设置条件,调用父类的
__call__
方法进行扩散生成。 - 对生成的潜在表示进行后处理,解码生成最终图像。
- 返回生成的图像。
总结
这个类实现了一个复杂的图像生成管道,结合了多种技术和模型,包括VAE、CLIP编码器、UNet和调度器。通过这种组合,能够在输入图像和文本提示的基础上生成高质量的图像。