代码地址:https://github.com/google/prompt-to-prompt/blob/main/prompt-to-prompt_stable.ipynb
论文地址:https://arxiv.org/abs/2208.01626
Cross-Attention Visualization
-
设置随机数生成器以及prompt
# Fixed-seed CPU generator; it is handed to the pipeline as `generator`.
g_cpu = torch.Generator().manual_seed(8888)
# One prompt only; the pipeline derives its batch size from len(prompts).
prompts = ["A painting of a squirrel eating a burger"]
-
声明一个注意力图控制器对象
# Controller that stores the attention maps handed to it during sampling.
controller = AttentionStore()
类AttentionStore继承于抽象类AttentionControl,需要着重关注以下几个方法:
class AttentionControl(abc.ABC):
    """Base class for hooks that are called on every attention map.

    Only ``__call__`` is shown in this excerpt. The counters it uses
    (``cur_att_layer``, ``cur_step``, ``num_att_layers``,
    ``num_uncond_att_layers``) and the ``forward`` / ``between_steps`` hooks
    live in the omitted part of the class or in subclasses.
    """

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Pass one attention map through ``forward`` and advance the counters.

        Two cases, selected by the module-level ``LOW_RESOURCE`` flag:

        1. Low-memory mode (``LOW_RESOURCE = True``): the conditional and
           unconditional embeddings are processed serially, so the calls
           belonging to the unconditional pass are skipped first and only
           the conditional attention maps reach ``forward``.
        2. High-memory mode (``LOW_RESOURCE = False``): both embeddings are
           processed in one batch, so the second half of ``attn`` (the
           conditional part) is passed to ``forward`` and written back.

        Returns:
            ``attn``, possibly replaced or modified by ``forward``.
        """
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # Author's note: the batch axis is laid out as
                # [uncond_embeddings, text_embeddings] * 8, so the second
                # half holds the conditional maps.
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        # Once every attention layer has reported (i.e. one denoising step is
        # complete), reset the layer counter and let the subclass fold the
        # maps collected during this step.
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    ...  # remaining members are omitted in this excerpt


class AttentionStore(AttentionControl):
    """Controller that accumulates attention maps over the denoising steps."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record ``attn`` for the current step and return it unchanged.

        Maps are stored under the key ``"<place_in_unet>_<cross|self>"``.
        Only maps whose second dimension is at most ``32 ** 2`` (resolution
        no larger than 32x32) are kept.
        """
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # attn.shape[1] values observed by the author: 4096, 1024, 256, 64.
        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Add this step's maps to the running sums, then reset the step store.

        The running sums are what ``get_average_attention`` later divides by
        the number of steps.
        """
        if not self.attention_store:
            # First completed step: adopt the step store as the accumulator.
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                accumulated = self.attention_store[key]
                current = self.step_store[key]
                for i in range(len(accumulated)):
                    accumulated[i] += current[i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the per-layer attention maps averaged over all steps so far.

        This is a plain (unweighted) mean: each accumulated map is divided by
        ``self.cur_step``. The author notes ``cur_step == 51`` at the end of
        their run.
        """
        average_attention = {
            key: [item / self.cur_step for item in maps]
            for key, maps in self.attention_store.items()
        }
        return average_attention

    ...  # remaining members are omitted in this excerpt
-
运行文生图Pipeline并显示最后的生成结果
# Run text-to-image sampling with the controller attached and display the result.
image, x_t = run_and_display(prompts, controller, latent=None, run_baseline=False, generator=g_cpu)
def run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None):
    """Generate images for ``prompts`` with ``controller`` attached and show them.

    When ``run_baseline`` is true, the function first calls itself once with
    an ``EmptyControl()`` controller, then feeds the latent returned by that
    run into the controlled run.

    Returns:
        The ``(images, x_t)`` pair produced by
        ``ptp_utils.text2image_ldm_stable``.
    """
    if run_baseline:
        print("w.o. prompt-to-prompt")
        _, latent = run_and_display(
            prompts,
            EmptyControl(),
            latent=latent,
            run_baseline=False,
            generator=generator,
        )
        print("with prompt-to-prompt")

    # Sampling settings come from module-level constants.
    sampling_options = dict(
        latent=latent,
        num_inference_steps=NUM_DIFFUSION_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=generator,
        low_resource=LOW_RESOURCE,
    )
    images, x_t = ptp_utils.text2image_ldm_stable(
        ldm_stable, prompts, controller, **sampling_options
    )
    ptp_utils.view_images(images)
    return images, x_t
@torch.no_grad()
def text2image_ldm_stable(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    low_resource: bool = False,
):
    """Run the text-to-image loop with ``controller`` hooked into attention.

    Steps, in order: register the controller on the model, encode the
    prompts and an equally sized batch of empty prompts, build the initial
    latents, then call ``diffusion_step`` once per scheduler timestep and
    decode the result with ``latent2image``.

    Args:
        model: Pipeline object exposing ``tokenizer``, ``text_encoder``,
            ``device``, ``scheduler`` and ``vae``.
        prompt: Prompts to render; the batch size is ``len(prompt)``.
        controller: Attention controller passed to
            ``register_attention_control`` and ``diffusion_step``.
        num_inference_steps: Value given to ``scheduler.set_timesteps``.
        guidance_scale: Forwarded to ``diffusion_step``.
        generator: Forwarded to ``init_latent``.
        latent: Optional starting latent, forwarded to ``init_latent``.
        low_resource: If true, the two embeddings are kept as a list
            ``[uncond, text]``; otherwise they are concatenated with
            ``torch.cat`` into one tensor.

    Returns:
        ``(image, latent)``: the decoded output and the first value returned
        by ``init_latent``.
    """
    register_attention_control(model, controller)
    height = width = 512
    batch_size = len(prompt)  # number of prompts

    # Conditional (prompt) embeddings.
    cond_tokens = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(cond_tokens.input_ids.to(model.device))[0]

    # Unconditional (empty-prompt) embeddings, padded to the same length.
    max_length = cond_tokens.input_ids.shape[-1]
    uncond_tokens = model.tokenizer(
        [""] * batch_size,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    uncond_embeddings = model.text_encoder(uncond_tokens.input_ids.to(model.device))[0]

    # With little GPU memory the two parts are evaluated separately,
    # otherwise they are evaluated together as one batch.
    if low_resource:
        context = [uncond_embeddings, text_embeddings]
    else:
        context = torch.cat([uncond_embeddings, text_embeddings])

    latent, latents = init_latent(latent, model, height, width, generator, batch_size)

    # NOTE: an earlier variant passed extra kwargs {"offset": 1} to the
    # scheduler when setting the timesteps; that is disabled here.
    model.scheduler.set_timesteps(num_inference_steps)

    # Denoise step by step.
    for t in tqdm(model.scheduler.timesteps):
        latents = diffusion_step(
            model, controller, latents, context, t, guidance_scale, low_resource
        )

    image = latent2image(model.vae, latents)
    return image, latent
需要着重关注register_attention_control(model, controller)方法,其主要作用为:修改Unet中每一层CrossAttention类的forward()方法,使得Unet在做注意力计算时可以保存其注意力图的参数。
def register_attention_control(model, controller):
    """Patch every ``CrossAttention`` module of ``model.unet`` to use ``controller``.

    Each module whose class is named ``CrossAttention`` gets its ``forward``
    replaced by a closure that computes attention as usual but passes the
    softmaxed attention map through
    ``controller(attn, is_cross, place_in_unet)`` before applying it to the
    values. ``place_in_unet`` is ``"down"``, ``"up"`` or ``"mid"``, chosen
    from the name of the top-level UNet child that contains the module.

    Args:
        model: Object with a ``unet`` attribute exposing ``named_children()``.
        controller: Callable ``(attn, is_cross, place_in_unet) -> attn``.
            If ``None``, a pass-through controller is used.

    Side effects:
        Overwrites ``forward`` on the matching modules and sets
        ``controller.num_att_layers`` to the number of patched modules.
    """

    def ca_forward(self, place_in_unet):
        """Build the replacement ``forward`` for one attention module."""
        to_out = self.to_out
        # When `to_out` is a ModuleList only its first entry is applied
        # (presumably the output projection -- TODO confirm against the
        # CrossAttention definition). isinstance also covers subclasses.
        if isinstance(to_out, torch.nn.ModuleList):
            to_out = to_out[0]

        def forward(x, context=None, mask=None):
            batch_size, _, _ = x.shape
            h = self.heads
            q = self.to_q(x)
            # No context means self-attention: keys and values come from x.
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q = self.reshape_heads_to_batch_dim(q)
            k = self.reshape_heads_to_batch_dim(k)
            v = self.reshape_heads_to_batch_dim(v)

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            # Author's note: only the no-mask case was examined; the masked
            # branch was not studied in detail.
            if mask is not None:
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            # Shapes observed by the author with a single prompt:
            #   self-attn:  [16, 4096, 4096] [16, 1024, 1024] [16, 256, 256] [16, 64, 64]
            #   cross-attn: [16, 4096, 77]   [16, 1024, 77]   [16, 256, 77]  [16, 64, 77]
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.reshape_batch_dim_to_heads(out)
            # out.shape observed by the author:
            #   [2, 4096, 320] [2, 1024, 640] [2, 256, 1280]
            #   [2, 64, 1280] [2, 256, 1280] [2, 1024, 640] [2, 4096, 320]
            return to_out(out)

        return forward

    class DummyController:
        """Pass-through controller used when ``controller`` is ``None``."""

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        """Patch ``CrossAttention`` modules under ``net_``; return the updated count."""
        if net_.__class__.__name__ == 'CrossAttention':
            # Install the forward that also reports the attention map.
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        # Total number of patched attention layers found so far.
        return count

    cross_att_count = 0
    for child_name, child in model.unet.named_children():
        # First matching tag wins, checked in the order down -> up -> mid.
        for place in ("down", "up", "mid"):
            if place in child_name:
                cross_att_count += register_recr(child, 0, place)
                break

    # Author's note: this comes out as 32 for the model used in the article.
    controller.num_att_layers = cross_att_count
得到最终的生成结果
-
显示每个token的交叉注意力图
show_cross_attention(controller, res=16, from_where=("up", "down"))


def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    """Display one cross-attention heat map per token of the selected prompt.

    The aggregated map from ``aggregate_attention`` is sliced along its last
    axis, one slice per token id of ``prompts[select]``. Each slice is scaled
    so that its maximum becomes 255, turned into a 3-channel uint8 image,
    resized to 256x256, captioned with the decoded token and finally shown
    together with the others via ``ptp_utils.view_images``.
    """
    token_ids = tokenizer.encode(prompts[select])
    decode = tokenizer.decode
    attention_maps = aggregate_attention(attention_store, res, from_where, True, select)

    panels = []
    for idx, token_id in enumerate(token_ids):
        # Attention map of this token (presumably res x res -- TODO confirm
        # against aggregate_attention).
        heat = attention_maps[:, :, idx]
        # Scale to the 0..255 grayscale range.
        gray = 255 * heat / heat.max()
        # Repeat the single channel three times -> (..., 3).
        rgb = gray.unsqueeze(-1).expand(*gray.shape, 3)
        pixels = rgb.numpy().astype(np.uint8)
        enlarged = np.array(Image.fromarray(pixels).resize((256, 256)))
        # Put the decoded token text under the picture.
        captioned = ptp_utils.text_under_image(enlarged, decode(int(token_id)))
        panels.append(captioned)

    ptp_utils.view_images(np.stack(panels, axis=0))
总结:
按照StableDiffusionPipeline的流程进行操作,修改UNet中CrossAttention的forward方法,记录每一次注意力计算得到的注意力图;再选取特定分辨率的注意力图,对各时间步求平均,得到最终的交叉注意力图。