“Prompt-to-Prompt Image Editing with Cross Attention Control” Code Analysis (Part 1)

Code: https://github.com/google/prompt-to-prompt/blob/main/prompt-to-prompt_stable.ipynb
Paper: https://arxiv.org/abs/2208.01626

Cross-Attention Visualization

  1. Set up the random-number generator and the prompt
    g_cpu = torch.Generator().manual_seed(8888)
    prompts = ["A painting of a squirrel eating a burger"]
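
    The later snippets also rely on a few globals defined earlier in the notebook (ldm_stable, tokenizer, LOW_RESOURCE, NUM_DIFFUSION_STEPS, GUIDANCE_SCALE). Below is a rough sketch of that setup with the values used in this walkthrough; the notebook's actual loading code (auth token, scheduler choice, etc.) may differ:

    import torch
    from diffusers import StableDiffusionPipeline

    LOW_RESOURCE = False          # False: uncond/cond passes in one batch; True: run them sequentially
    NUM_DIFFUSION_STEPS = 50
    GUIDANCE_SCALE = 7.5

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
    tokenizer = ldm_stable.tokenizer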
  2. Declare an attention-map controller object
    controller = AttentionStore()

    The class AttentionStore inherits from the abstract class AttentionControl. The following methods deserve particular attention (a toy sketch of the accumulate-and-average flow is given after the class listing):

    class AttentionControl(abc.ABC):   
        def __call__(self, attn, is_cross: bool, place_in_unet: str):
            """
            分两种情况:1.低显存模式->LOW_RESOURCE = True
                         条件嵌入和无条件嵌入串行进行运算,需要先越过无条件嵌入部分结果,只保留条件嵌入部分的注意力图
                       2.高显存模式->LOW_RESOURCE = False
                         条件嵌入和无条件嵌入并行进行运算,直接取后半部分的条件嵌入的注意力图
            """
            if self.cur_att_layer >= self.num_uncond_att_layers:
                if LOW_RESOURCE:
                    attn = self.forward(attn, is_cross, place_in_unet)
                else:
                    # batch = [uncond_embeddings, text_embeddings] replicated over 8 heads; keep only the conditional half
                    h = attn.shape[0]
                    attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
            self.cur_att_layer += 1
            # Once the attention maps of all layers have been collected (i.e. one denoising step is done), accumulate them per layer
            if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
                self.cur_att_layer = 0
                self.cur_step += 1
                self.between_steps()
            return attn
        ...
    
    class AttentionStore(AttentionControl):
    
        def forward(self, attn, is_cross: bool, place_in_unet: str):
            """
            保存分辨率小于32*32的注意力图
            """
            key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
            # Only attention maps from layers with resolution at most 32*32 are recorded
            # attn.shape[1] 4096 1024 256 64
            if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
                self.step_store[key].append(attn)
            return attn
    
        def between_steps(self):
            # Accumulate the per-layer attention maps across steps, in preparation for averaging later
            if len(self.attention_store) == 0:
                self.attention_store = self.step_store
            else:
                for key in self.attention_store:
                    for i in range(len(self.attention_store[key])):
                        self.attention_store[key][i] += self.step_store[key][i]
            self.step_store = self.get_empty_store()
    
        def get_average_attention(self):
            # Average the accumulated per-layer attention maps over all denoising steps
            # at this point cur_step equals the total number of completed denoising steps (51 in the author's run)
            average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
            return average_attention
    
    
        ...
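
    To make the accumulate-and-average flow concrete, here is a toy, self-contained sketch (not from the repo, with hypothetical shapes) that mimics what between_steps() and get_average_attention() do across denoising steps:

    import torch

    num_steps, maps_per_step = 3, 2
    attention_store = {}

    for step in range(num_steps):
        # per-step store: a list of attention maps per location key, as in step_store
        step_store = {"down_cross": [torch.rand(16, 256, 77) for _ in range(maps_per_step)]}
        if len(attention_store) == 0:
            attention_store = step_store                    # first step: keep the maps as-is
        else:
            for key in attention_store:
                for i in range(len(attention_store[key])):
                    attention_store[key][i] += step_store[key][i]   # running per-layer sum

    average = {k: [m / num_steps for m in v] for k, v in attention_store.items()}
    print(average["down_cross"][0].shape)  # torch.Size([16, 256, 77])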
    
       
  3. Run the text-to-image pipeline and display the final generated result
    image, x_t = run_and_display(prompts, controller, latent=None, run_baseline=False, generator=g_cpu)
    
    def run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None):
        if run_baseline:
            print("w.o. prompt-to-prompt")
            images, latent = run_and_display(prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator)
            print("with prompt-to-prompt")
        images, x_t = ptp_utils.text2image_ldm_stable(ldm_stable, prompts, controller, latent=latent, num_inference_steps=NUM_DIFFUSION_STEPS, guidance_scale=GUIDANCE_SCALE, generator=generator, low_resource=LOW_RESOURCE)
        ptp_utils.view_images(images)
        return images, x_t
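
    When run_baseline=True, the image is first generated with EmptyControl(), a controller that passes every attention map through unchanged. A sketch of what such a class looks like (the notebook defines its own version along these lines):

    class EmptyControl(AttentionControl):
        def forward(self, attn, is_cross: bool, place_in_unet: str):
            # do not record or modify anything; just return the attention map as-is
            return attn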
    @torch.no_grad()
    def text2image_ldm_stable(
        model,
        prompt: List[str],
        controller,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        generator: Optional[torch.Generator] = None,
        latent: Optional[torch.FloatTensor] = None,
        low_resource: bool = False,
    ):
        register_attention_control(model, controller)
        height = width = 512
        # get the number of prompts
        batch_size = len(prompt)
    
        text_input = model.tokenizer(
            prompt,
            padding="max_length",
            max_length=model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
        max_length = text_input.input_ids.shape[-1]
        uncond_input = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
        
        context = [uncond_embeddings, text_embeddings]
        # If GPU memory is limited, run the unconditional and conditional passes separately; otherwise run them in one batch
        if not low_resource:
            context = torch.cat(context)
        latent, latents = init_latent(latent, model, height, width, generator, batch_size)
        
        # set timesteps
        #extra_set_kwargs = {"offset": 1}
        #model.scheduler.set_timestep(num_inference_steps, **extra_set_kwargs)
        model.scheduler.set_timesteps(num_inference_steps)
        # Denoise step by step (diffusion_step is sketched after this listing)
        for t in tqdm(model.scheduler.timesteps):
            latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)
        
        image = latent2image(model.vae, latents)
      
        return image, latent
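
    text2image_ldm_stable delegates the actual denoising update to diffusion_step (defined in ptp_utils but not reproduced in this post). Roughly, it performs classifier-free guidance and one scheduler step; the sketch below reflects my reading of the helper and may differ in details (e.g. output dict keys) across diffusers versions:

    def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
        if low_resource:
            # two sequential UNet passes: unconditional first, then conditional
            noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
            noise_pred_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
        else:
            # one batched pass: [uncond, cond] stacked along the batch dimension,
            # which is why AttentionControl.__call__ keeps only the second half of attn
            latents_input = torch.cat([latents] * 2)
            noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        # classifier-free guidance
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
        latents = controller.step_callback(latents)  # a no-op for AttentionStore
        return latents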
    The method register_attention_control(model, controller) deserves special attention. Its main purpose is to replace the forward() method of every CrossAttention layer in the UNet, so that the attention maps can be captured each time the UNet computes attention (a small sanity-check sketch follows the listing below).
    def register_attention_control(model, controller):
        def ca_forward(self, place_in_unet):
            to_out = self.to_out
            if type(to_out) is torch.nn.modules.container.ModuleList:
                to_out = self.to_out[0]
            else:
                to_out = self.to_out
    
            def forward(x, context=None, mask=None):
                batch_size, sequence_length, dim = x.shape
                h = self.heads
                q = self.to_q(x)
                is_cross = context is not None
                context = context if is_cross else x
                k = self.to_k(context)
                v = self.to_v(context)
                q = self.reshape_heads_to_batch_dim(q)
                k = self.reshape_heads_to_batch_dim(k)
                v = self.reshape_heads_to_batch_dim(v)
    
                sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
                # only the mask-free case is considered here; this branch is not analyzed in detail
                if mask is not None:
                    mask = mask.reshape(batch_size, -1)
                    max_neg_value = -torch.finfo(sim.dtype).max
                    mask = mask[:, None, :].repeat(h, 1, 1)
                    sim.masked_fill_(~mask, max_neg_value)
    
                # attention, what we cannot get enough of
                # the attention map records, for each pixel (query), a weight for every key/token
                attn = sim.softmax(dim=-1)
                # attn.shape [16, 4096, 4096] [16, 1024, 1024] [16, 256, 256] [16, 64, 64]
                #            [16, 4096, 77] [16, 1024, 77] [16, 256, 77] [16, 64, 77]
                # only one prompt was provided
                attn = controller(attn, is_cross, place_in_unet)
                out = torch.einsum("b i j, b j d -> b i d", attn, v)
                out = self.reshape_batch_dim_to_heads(out)
                # out.shape -> [2, 4096, 320] [2, 1024, 640] [2, 256, 1280] 
                # [2, 64, 1280] [2, 256, 1280] [2, 1024, 640] [2, 4096, 320]
                return to_out(out)
    
            return forward
    
        class DummyController:
    
            def __call__(self, *args):
                return args[0]
    
            def __init__(self):
                self.num_att_layers = 0
    
        if controller is None:
            controller = DummyController()
    
        def register_recr(net_, count, place_in_unet):
            if net_.__class__.__name__ == 'CrossAttention':
                # return a forward method used to calculate the attention map
                net_.forward = ca_forward(net_, place_in_unet)
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            # record the total number of layers of calculating the attention map
            return count
    
        cross_att_count = 0
        sub_nets = model.unet.named_children()
        for net in sub_nets:
            if "down" in net[0]:
                cross_att_count += register_recr(net[1], 0, "down")
            elif "up" in net[0]:
                cross_att_count += register_recr(net[1], 0, "up")
            elif "mid" in net[0]:
                cross_att_count += register_recr(net[1], 0, "mid")
    
        controller.num_att_layers = cross_att_count
        # controller.num_att_layers = 32
    This yields the final generated result.
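
    As a quick sanity check on controller.num_att_layers, one can count the matching modules in the UNet directly. A minimal sketch (note that newer diffusers releases rename CrossAttention to Attention, so the class-name string may need adjusting):

    def count_cross_attention_layers(unet) -> int:
        # count modules whose class name matches the one register_recr looks for
        return sum(
            1
            for _, module in unet.named_modules()
            if module.__class__.__name__ == "CrossAttention"
        )

    # hypothetical usage; the walkthrough above expects 32 layers for SD v1.4
    # print(count_cross_attention_layers(ldm_stable.unet))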
  4. Display the cross-attention map of each token (a sketch of the aggregate_attention helper follows the listing)
    show_cross_attention(controller, res=16, from_where=("up", "down"))
    
    def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
        tokens = tokenizer.encode(prompts[select])
        decoder = tokenizer.decode
        attention_maps = aggregate_attention(attention_store, res, from_where, True, select)
        images = []
        for i in range(len(tokens)):
            # extract the attention map of the i-th token -> [16, 16]
            image = attention_maps[:, :, i]
            # normalize to an 8-bit grayscale range
            image = 255 * image / image.max()
            # expand the single-channel grayscale map to three channels -> [16, 16, 3]
            image = image.unsqueeze(-1).expand(*image.shape, 3)
            image = image.numpy().astype(np.uint8)
            image = np.array(Image.fromarray(image).resize((256, 256)))
            # add the decoded token text below the displayed image
            image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
            images.append(image)
        ptp_utils.view_images(np.stack(images, axis=0))
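
    show_cross_attention relies on the helper aggregate_attention, which is not reproduced above. Roughly, it takes the averaged maps from the store, keeps only those at the requested resolution (res=16 here) and UNet locations, reshapes them to res x res, and averages over heads and layers. Below is a sketch consistent with how it is called here (it reuses the global prompts; the notebook's actual helper may differ in details):

    def aggregate_attention(attention_store, res, from_where, is_cross, select):
        out = []
        attention_maps = attention_store.get_average_attention()
        num_pixels = res ** 2
        for location in from_where:                  # e.g. ("up", "down")
            key = f"{location}_{'cross' if is_cross else 'self'}"
            for item in attention_maps[key]:
                if item.shape[1] == num_pixels:      # keep only maps at res x res
                    maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                    out.append(maps)
        out = torch.cat(out, dim=0)                  # stack heads from all matching layers
        out = out.sum(0) / out.shape[0]              # average over heads and layers
        return out.cpu()                             # [res, res, num_tokens]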
    

Summary:

        The walkthrough follows the standard StableDiffusionPipeline while replacing the forward method of each CrossAttention layer in the UNet, so that every attention map is recorded during denoising. Attention maps at selected resolutions are then accumulated over all timesteps and averaged to obtain the final cross-attention maps.
