【DA-CLIP】encode_image图像编码过程和IRSDE对image_context,、degra_context的使用

 背景:

编码过程:

  with torch.no_grad(), torch.cuda.amp.autocast():
        # 这一行开始一个上下文管理器,用于关闭梯度计算(torch.no_grad()),这对于推理阶段是必要的,因为我们不需要计算反向传播。
        # torch.cuda.amp.autocast()用于自动将操作转换为半精度浮点数,这可以提高计算速度并减少内存使用
        image_context, degra_context = clip_model.encode_image(img4clip, control=True)
        # 这一行使用clip_model的encode_image方法来编码处理后的图像,生成图像的上下文信息。
        # control=True使用DA-CLIP的encode_image而非CLIP
        #
        image_context = image_context.float()
        # 这一行将图像上下文张量转换为浮点数类型。
        degra_context = degra_context.float()
        # 同上将degra_context张量转换为浮点数类型

根据 复原过程探究

【DA-CLIP】复原过程代码解读-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_60350022/article/details/137699012image_context,、degra_context被传入ConditionUnet的forward()计算noise张量再计算分数score

encode_image过程

已知clip_model是DaCLIP类实例,在open_clip/declip_model.py找到该方法定义

    def encode_image(self, image, control=False, normalize: bool = False):
        if control:
            degra_features, hiddens = self.visual_control(image, output_hiddens=True)
            image_features = self.clip.visual(image, control=hiddens)
            
            image_features = F.normalize(image_features, dim=-1) if normalize else image_features
            degra_features = F.normalize(degra_features, dim=-1) if normalize else degra_features
            return image_features, degra_features
        else:
            return self.clip.encode_image(image, normalize)
            # 相当于image_features = self.clip.visual(image, control=hiddens)
            # image_features = F.normalize(image_features, dim=-1) if normalize else image_features

DaCLIP类self.clip是clip实例,control为False,只返回一个image_feature没有degra_feature就是CLIP原始的图像编码过程。

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

一、degra_features解析。

self.visual_control()

self.visual_control = copy.deepcopy(clip_model.visual)
self.visual_control.transformer = ControlTransformer(self.visual_control.transformer)

 visual_control方法是clip_model.visual方法的深拷贝,但是更换了transformer 

_build_vision_tower ()

clip_model.visual是_build_vision_tower方法的返回值,

这个函数通过不同的配置参数来决定使用哪种类型的模型,包括使用 timm 库中的模型、ModifiedResNet 或者 VisionTransformer。函数中使用了条件判断来根据不同的配置参数构建相应的视觉模型,并且对于不同的数据类型需求,选择了不同的归一化层。

 self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

该方法接受 vision_cfg配置

class CLIPVisionCfg:
    # 定义Vision Transformer模型中的层数,可以是一个整数或者一个包含四个整数的元组,
    # 如果是元组,它代表了不同层的维度,通常用于构建具有不同宽度的模型。
    layers: Union[Tuple[int, int, int, int], int] = 12
    # 模型的宽度,即每个层中的通道数。
    width: int = 768
    # 头部的宽度,即在多头自注意力机制中每个头的维度。
    head_width: int = 64
    # 多层感知机(MLP)层的宽度与输入的宽度之比。
    mlp_ratio: float = 4.0
    # 图像分块的大小,即将输入图像划分为多少个小块(patches)。
    patch_size: int = 16
    # 输入图像的尺寸,可以是一个整数或者一个包含两个整数的元组。
    image_size: Union[Tuple[int, int], int] = 224

    # 层归一化(Layer Scale)的初始值。
    ls_init_value: Optional[float] = None
    # 在训练过程中要丢弃的补丁(patches)的比例。值为0表示不丢弃任何补丁,
    # 而0.5到0.75之间是推荐值,用于优化模型性能。
    patch_dropout: float = 0.
    # 是否在每个补丁上使用输入层归一化,即是否对每个补丁应用输入层的层归一化。
    input_patchnorm: bool = False
    # 是否在最后一个嵌入层使用全局平均池化(Global Average Pooling),而不是使用CLS标记。
    global_average_pool: bool = False
    # 是否在最后一个嵌入层使用注意力池化(Attentional Pooling)。
    attentional_pool: bool = False
    # 注意力池化器(Attentional Pooler)的查询数。
    n_queries: int = 256
    # 注意力池化器的头数。
    attn_pooler_heads: int = 8
    # 是否输出token,通常用于模型的最终输出。
    output_tokens: bool = False

    # Timm库中的有效模型名称,如果设置了这个值,将会覆盖layers、width和patch_size等配置。
    timm_model_name: str = None
    # 是否使用预训练权重(通常是在ImageNet数据集上预训练的)。
    timm_model_pretrained: bool = False
    # Timm模型的特征池化类型。
    timm_pool: str = 'avg'
    # Timm模型输出的线性投影类型。
    timm_proj: str = 'linear'
    # 是否启用最终投影的偏置项。
    timm_proj_bias: bool = False
    # 头部的dropout比率。
    timm_drop: float = 0.
    # Timm模型的路径丢弃比率,也称为随机深度(Stochastic Depth)。
    timm_drop_path: Optional[float] = None

下面是_build_vision_tower 定义,在判断中

timm_model_name初始为None

vision_cfg.layers为12(int)

以上判断不满足,所以进入else。根据cfg配置VisionTransformer

visual = VisionTransformer()

# build_vision_tower,它接受以下参数:
# embed_dim: 整数,表示嵌入维度。
# vision_cfg: CLIPVisionCfg 类型,包含了配置视觉模型所需的参数。
# quick_gelu: 布尔值,可选参数,默认为 False。用于指定是否使用 QuickGELU 激活层。
# cast_dtype: 可选参数,默认为 None,表示在 PyTorch 中的类型转换。

def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    # 如果传入的 vision_cfg 是一个字典,将其转换为 CLIPVisionCfg 类的实例。
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # 根据 quick_gelu 参数的值选择使用 QuickGELU 还是原生的 nn.GELU 激活层。
    # QuickGELU 在 OpenAI 的模型中预训练使用,但在 PyTorch 1.10 及更高版本中,原生的 GELU 更快且更节省内存。
    # 注意:无论 quick_gelu 标志如何,timm 模型始终使用原生 GELU。
    act_layer = QuickGELU if quick_gelu else nn.GELU

    # 如果 vision_cfg 中包含了 timm_model_name,表示使用 timm 库中的模型。
    if vision_cfg.timm_model_name:
        visual = TimmModel(
            vision_cfg.timm_model_name,  # 使用指定的 timm 模型名称。
            pretrained=vision_cfg.timm_model_pretrained,  # 是否使用预训练权重。
            pool=vision_cfg.timm_pool,  # 指定池化层的类型。
            proj=vision_cfg.timm_proj,  # 投影层的输出维度。
            proj_bias=vision_cfg.timm_proj_bias,  # 投影层是否包含偏置项。
            drop=vision_cfg.timm_drop,  # 指定 dropout 比率。
            drop_path=vision_cfg.timm_drop_path,  # 指定 drop path 正则化的比率。
            patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,  # 指定 patch dropout 比率。
            embed_dim=embed_dim,  # 嵌入维度。
            image_size=vision_cfg.image_size,  # 输入图像的尺寸。
        )
    # 如果 vision_cfg 中包含了 layers,且类型为元组或列表。
    elif isinstance(vision_cfg.layers, (tuple, list)):
        # 计算视觉头的数量。
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        # 使用 ModifiedResNet 构建视觉模型。
        visual = ModifiedResNet(
            layers=vision_cfg.layers,  # 指定 ResNet 的层数。
            output_dim=embed_dim,  # 输出维度。
            heads=vision_heads,  # 视觉头的数量。
            image_size=vision_cfg.image_size,  # 输入图像的尺寸。
            width=vision_cfg.width,  # 模型的宽度。
        )
    # 如果以上条件都不满足,则使用 VisionTransformer 构建视觉模型。
    else:
        # 计算视觉头的数量。
        vision_heads = vision_cfg.width // vision_cfg.head_width
        # 根据 cast_dtype 的值选择使用不同的 LayerNorm 层。
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        # 使用 VisionTransformer 构建视觉模型。
        visual = VisionTransformer(
            image_size=vision_cfg.image_size,  # 输入图像的尺寸。
            patch_size=vision_cfg.patch_size,  # 补丁的大小。
            width=vision_cfg.width,  # 模型的宽度。
            layers=vision_cfg.layers,  # 指定 Transformer 的层数。
            heads=vision_heads,  # 视觉头的数量。
            mlp_ratio=vision_cfg.mlp_ratio,  # MLP 的比例。
            ls_init_value=vision_cfg.ls_init_value,  # 层归一化初始化值。
            patch_dropout=vision_cfg.patch_dropout,  # 指定补丁 dropout 比率。
            input_patchnorm=vision_cfg.input_patchnorm,  # 是否使用输入补丁归一化。
            global_average_pool=vision_cfg.global_average_pool,  # 是否使用全局平均池化。
            attentional_pool=vision_cfg.attentional_pool,  # 是否使用注意力池化。
            n_queries=vision_cfg.n_queries,  # 查询的数量。
            attn_pooler_heads=vision_cfg.attn_pooler_heads,  # 注意力池化头的数量。
            output_tokens=vision_cfg.output_tokens,  # 输出的 token 数量。
            output_dim=embed_dim,  # 输出维度。
            act_layer=act_layer,  # 激活层。
            norm_layer=norm_layer,  # 归一化层。
        )

    # 返回构建好的视觉模型实例。
    return visual

 配置好VisionTransformer之后更换transformer,

 self.visual_control深拷贝clip_model.visual,self.visual_control.transformer就是ViT的transformer,

  self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer
        )

ControlTransformer()

使用自定义的ControlTransformer,其构造函数中接受一个 transformer 参数,该参数是 Transformer 类的一个实例。ControlTransformer 类的主要目的是通过添加额外的控制机制来增强原始 Transformer 模型的功能。这种控制机制通过一个 control 张量实现,该张量可以在模型的前向传播过程中与每个残差块的输出相加,从而对模型的输出进行调节。

import torch
import torch.nn as nn
from typing import Optional, Callable

# 假设 LayerNorm 和 ResidualAttentionBlock 已经在其他地方定义
# LayerNorm = ... 
# ResidualAttentionBlock = ...

class ControlTransformer(nn.Module):
    def __init__(self, transformer: nn.Module):
        """
        初始化 ControlTransformer 实例。
        
        :param transformer: 一个 nn.Module 类型的实例,它应该是一个 Transformer 模型。
                              ControlTransformer 将使用这个模型的层数和宽度,并添加控制机制。
        """
        super().__init__()  # 调用父类的构造函数
        self.transformer = transformer  # 保存传入的 Transformer 实例
        self.layers = transformer.layers  # 获取 Transformer 的层数
        self.width = transformer.width  # 获取 Transformer 的宽度(即特征维度)

        # 创建一个 ModuleList,其中包含与 Transformer 层数相同数量的零模块
        # 每个零模块都是一个线性层,用于生成全零的向量
        self.zero_modules = nn.ModuleList([
            self.zero_module(nn.Linear(self.width, self.width, 1))
            for _ in range(self.layers)]).cuda()  # 将零模块移动到 GPU 上(如果可用)

        # 保存 Transformer 实例的梯度检查点(grad checkpointing)设置
        self.grad_checkpointing = transformer.grad_checkpointing

    # 定义一个辅助函数,用于将模块的参数归零
    def zero_module(self, module: nn.Module) -> nn.Module:
        """
        将传入模块的参数归零并返回该模块。
        
        :param module: 一个 nn.Module 类型的实例,其参数将被归零。
        :return: 传入的模块,其参数已被归零。
        """
        for p in module.parameters():
            p.detach().zero_()  # 归零参数
        return module  # 返回修改后的模块

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, 
                output_hiddens: Optional[bool] = False, 
                control: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
        """
        前向传播函数。
        
        :param x: 输入张量。
        :param attn_mask: 可选的注意力掩码张量。
        :param output_hiddens: 可选的布尔值,指示是否输出每层的隐藏状态。
        :param control: 可选的控制张量,用于调节模型的输出。
        :return: 如果 output_hiddens 为 True,则返回输入 x 和每层的隐藏状态(hiddens);
                 否则只返回输入 x。
        """
        if output_hiddens:
            hiddens = []  # 如果需要输出隐藏状态,则初始化一个列表来存储它们

        # 遍历 Transformer 的每一层和对应的零模块
        for z, r in zip(self.zero_modules, self.transformer.resblocks):
            # 如果启用了梯度检查点,则使用 checkpoint 装饰器来保存和恢复梯度
            if self.grad_checkpointing and not torch.jit.is_scripting():
               
                x = checkpoint(r, x, None, None, attn_mask)
            else:
                # 否则,直接调用残差块的前向传播方法
                x = r(x, attn_mask=attn_mask)
            
            # 应用零模块生成零向量
            zx = z(x)

            # 如果需要输出隐藏状态,则将其添加到列表中
            if output_hiddens:
                hiddens.append(zx)

            # 如果提供了控制张量,则将其添加到当前层的输出中
            if control is not None:
                x += control.pop()

        # 根据 output_hiddens 的值返回相应的结果
        return (x, hiddens) if output_hiddens else x

ViT的 Transformer

class Transformer(nn.Module):
    def __init__(  # 构造函数定义了如何初始化 Transformer 模型
            self,
            width: int,          # 模型的宽度,即特征维度
            layers: int,         # Transformer 块的层数
            heads: int,          # 多头注意力中头的数量
            mlp_ratio: float = 4.0,  # MLP层的宽度与输入宽度的比率
            ls_init_value: float = None,  # 层归一化(LayerScale)的初始值
            act_layer: Callable = nn.GELU,  # 激活层的类型
            norm_layer: Callable = LayerNorm,  # 归一化层的类型
    ):
        super().__init__()  # 调用父类的构造函数
        self.width = width  # 保存模型宽度
        self.layers = layers  # 保存层数
        self.grad_checkpointing = False  # 默认不启用梯度检查点

        # 创建一个 ModuleList,存储所有的残差注意力块
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(  # 为每个层创建一个残差注意力块
                width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer
            ) for _ in range(layers)
        )

    def get_cast_dtype(self) -> torch.dtype:  # 获取用于类型转换的原始数据类型
        # 假设 ResidualAttentionBlock 和其 MLP 层已经在其他地方定义
        # 这里只是一个示例,实际的实现可能需要访问具体的 MLP 层和其属性
        if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
            return self.resblocks[0].mlp.c_fc.int8_original_dtype
        return self.resblocks[0].mlp.c_fc.weight.dtype

    def forward(self,  # 前向传播函数
            x: torch.Tensor,  # 输入张量
            attn_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码
            output_hiddens: Optional[bool] = False,  # 可选的标志,指示是否输出隐藏状态
    ):
        if output_hiddens:
            hiddens = []  # 如果需要输出隐藏状态,则初始化一个列表来存储它们

        for r in self.resblocks:  # 遍历所有的残差注意力块
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # 使用梯度检查点来节省内存
                x = checkpoint(r, x, None, None, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)  # 调用残差注意力块的前向传播方法

            if output_hiddens:
                hiddens.append(x)  # 如果需要,保存当前层的输出

        return (x, hiddens) if output_hiddens else x  # 根据需要返回最终输出和/或隐藏状态
  1. 控制机制: ControlTransformer 引入了控制机制,它通过一个额外的 control 参数来调节模型的输出。这是 ControlTransformer 独有的特性,而原始的 Transformer 没有这个功能。

  2. 零模块(Zero Modules): ControlTransformer 创建了一个 zero_modulesModuleList,其中包含零初始化的线性层。这些零模块用于生成全零的向量,这些向量可以在前向传播中与残差块的输出相加。原始的 Transformer 没有这个列表和这些零模块。

  3. 梯度检查点(Grad Checkpointing): 虽然两个类都有关于梯度检查点的代码,但 ControlTransformer 在其构造函数中直接从传入的 transformer 参数中继承了这个设置。这意味着 ControlTransformer 可以利用原始 Transformer 的梯度检查点配置。

  4. 前向传播方法(Forward Method): 两个类的 forward 方法在结构上相似,但 ControlTransformer 在处理每层的输出时会考虑 control 参数,并将其与输出相加。

  5. 类之间的关系: ControlTransformer 继承自 nn.Module 并接受一个 transformer 参数,这个参数应该是 Transformer 类的一个实例。这样,ControlTransformer 可以在原始 Transformer 的基础上添加额外的功能。

 

小结 

 self.visual_control()应该是VisionTransformer,Transformer更换了ControlTransformer

 degra_features, hiddens = self.visual_control(image, output_hiddens=True)就是执行VisionTransformer的forward()

def forward(self, x: torch.Tensor, output_hiddens: bool = False, control: Optional[torch.Tensor] = None):
    """
    前向传播函数,定义了 Vision Transformer 模型如何处理输入数据 x 并生成输出。

    :param x: 输入的张量,代表一批图像。
    :param output_hiddens: 布尔值,指示是否输出每层的隐藏状态。
    :param control: 可选的控制张量,用于调节模型的输出。
    :return: 根据模型配置返回不同的输出,可能是池化后的输出、隐藏状态或注意力池化的 token。
    """

    # 如果启用了输入补丁归一化(input patch normalization),则对输入图像的每个补丁进行处理
    if self.input_patchnorm:
        # 使用 einops 库对输入图像进行分块(patchification)
        x = x.reshape(...)  # 重新排列和重塑张量以匹配补丁嵌入的格式
        x = self.patchnorm_pre_ln(x)  # 应用第一个 LayerNorm
        x = self.conv1(x)  # 应用卷积层以生成补丁嵌入
    else:
        # 如果没有启用输入补丁归一化,直接应用卷积层
        x = self.conv1(x)
        # 重塑和排列张量以准备嵌入和嵌入位置
        x = x.reshape(...)  # 重塑张量以匹配补丁嵌入的格式
        x = x.permute(...)  # 排列张量维度以匹配所需的格式

    # 将类嵌入(class embeddings)添加到输入补丁,并与位置嵌入(positional embeddings)相加
    x = torch.cat([self.class_embedding, x], dim=1)

    # 应用补丁dropout(如果dropout比例不为0)
    x = self.patch_dropout(x)

    # 应用预Transformer层的LayerNorm
    x = self.ln_pre(x)

    # 调整张量维度顺序以适应Transformer的输入要求(从 NLD 到 LND)
    x = x.permute(1, 0, 2)

    # 通过Transformer层进行前向传播,并可选择性地输出隐藏状态
    x = self.transformer(x, output_hiddens=output_hiddens, control=control)
    if output_hiddens:
        x, hiddens = x  # 如果需要,解包Transformer的输出以获取隐藏状态

    # 再次调整张量维度顺序(从 LND 回到 NLD)
    x = x.permute(1, 0, 2)

    # 如果模型包含注意力池化器,则应用它并获取池化后的输出和token
    if self.attn_pool is not None:
        x = self.attn_pool(x)
        x = self.ln_post(x)
        pooled, tokens = self._global_pool(x)
    else:
        # 否则,直接应用全局池化并获取池化后的输出
        pooled, tokens = self._global_pool(x)
        pooled = self.ln_post(pooled)

    # 如果存在输出投影层,则应用它
    if self.proj is not None:
        pooled = pooled @ self.proj

    # 如果模型配置为输出token,则返回它们
    if self.output_tokens:
        return pooled, tokens

    # 如果需要输出隐藏状态,则返回它们
    if output_hiddens:
        return pooled, hiddens
    
    # 否则,只返回池化后的输出
    return pooled

返回的是一个包含两个元素的元组,其中第一个元素是 pooled(经过处理的特征表示),第二个元素是 hiddens(每个Transformer层的输出) 

猪脑过载了有机会再探究为什么叫degradation_feature

二、image_features = self.clip.visual(image, control=hiddens)

 这个相比上面的理解就简单了,

self.clip.visual就是VisionTransformer,

image_features = self.clip.visual(image, control=hiddens)同样执行VisionTransformer的forward

代码还是上面那段。

control的输入是刚刚生成degra_features的返回值hiddens

三、image_features = F.normalize(image_features, dim=-1) if normalize else image_features

【Pytorch】F.normalize计算理解-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/lj2048/article/details/118115681默认normalize为False,暂时不去理解这里

score计算过程

预处理

image_context = image_context.float()
# 这一行将图像上下文张量转换为浮点数类型。
degra_context = degra_context.float()
# 同上将degra_context张量转换为浮点数类型

noise? 

 noise = self.model(x, self.mu, t * scale, **kwargs) 

 ConditionalUNet的forward 接收输入张量 xt、条件张量 cond、时间参数 time,以及可选的文本和图像上下文。通过一系列卷积、注意力机制和上/下采样操作,模型提取和融合特征,最终生成输出张量。输出张量经过裁剪,确保其尺寸与输入图像的尺寸一致。

def forward(self, xt, cond, time, text_context=None, image_context=None):
    # 检查输入的时间参数是否为整数或浮点数,如果是,则将其转换为一个单元素张量,并移动到xt所在的设备
    if isinstance(time, int) or isinstance(time, float):
        time = torch.tensor([time]).to(xt.device)
    
    # X=noisy_tensor-LQ_tensor就是文章第一步添加的随机噪声,与LQ_tensor拼接,增加通道维度
    x = xt - cond
    x = torch.cat([x, cond], dim=1)

    # 获取输入张量的空间维度H和W
    H, W = x.shape[2:]
    # 检查并调整输入张量x的空间尺寸以匹配原始图像的尺寸
    x = self.check_image_size(x, H, W)

    # 应用初始卷积层
    x = self.init_conv(x)
    # 克隆x,用于后续操作
    x_ = x.clone()

    # 通过时间MLP处理时间参数
    t = self.time_mlp(time) 
    # 如果上下文维度大于0,并且使用degra上下文,且文本上下文不为空
    if self.context_dim > 0:
        if self.use_degra_context and text_context is not None:
            # 计算文本上下文的嵌入,将其与提示向量结合,并进行处理
            prompt_embedding = torch.softmax(self.text_mlp(text_context), dim=1) * self.prompt
            prompt_embedding = self.prompt_mlp(prompt_embedding)
            # 将处理后的文本上下文嵌入加到时间参数t上
            t = t + prompt_embedding

        # 如果使用图像上下文,且图像上下文不为空
        if self.use_image_context and image_context is not None:
            # 为图像上下文增加一个通道维度
            image_context = image_context.unsqueeze(1)

    # 存储下采样过程中的特征图
    h = []
    # 遍历下采样模块列表
    for b1, b2, attn, downsample in self.downs:
        # 应用第一个残差块和时间参数t
        x = b1(x, t)
        # 存储特征图
        h.append(x)

        # 应用第二个残差块和时间参数t
        x = b2(x, t)
        # 应用注意力机制,如果提供了图像上下文,则使用它
        x = attn(x, context=image_context)
        # 存储特征图
        h.append(x)

        # 应用下采样操作
        x = downsample(x)

    # 应用中间块1和时间参数t
    x = self.mid_block1(x, t)
    # 如果使用图像上下文,则应用注意力机制
    x = self.mid_attn(x, context=image_context) if self.use_image_context else x
    # 应用中间块2和时间参数t
    x = self.mid_block2(x, t)

    # 遍历上采样模块列表
    for b1, b2, attn, upsample in self.ups:
        # 从历史特征图中弹出并拼接特征,与当前特征图拼接
        x = torch.cat([x, h.pop()], dim=1)
        # 应用第一个残差块和时间参数t
        x = b1(x, t)
        
        # 再次从历史特征图中弹出并拼接特征,与当前特征图拼接
        x = torch.cat([x, h.pop()], dim=1)
        # 应用第二个残差块和时间参数t
        x = b2(x, t)

        # 应用注意力机制,如果提供了图像上下文,则使用它
        x = attn(x, context=image_context)
        # 应用上采样操作
        x = upsample(x)

    # 将原始输入xt与当前特征图x拼接,增加通道维度
    x = torch.cat([x, x_], dim=1)

    # 应用最终的残差块和时间参数t
    x = self.final_res_block(x, t)
    # 应用最终的卷积层
    x = self.final_conv(x)

    # 裁剪输出张量x,使其空间尺寸与原始输入图像的尺寸相匹配
    x = x[..., :H, :W].contiguous()
    
    # 返回处理后的输出张量x
    return x

 根据返回的noise计算score

    def get_score_from_noise(self, noise, t):
        return -noise / self.sigma_bar(t)
    def sigma_bar(self, t):
        return self.sigma_bars[t]
      sigma_bars = get_sigma_bars(thetas_cumsum)
thetas_cumsum = get_thetas_cumsum(thetas) - thetas[0] # for that thetas[0] is not 0
def get_sigma_bars(thetas_cumsum):
    return torch.sqrt(max_sigma**2 * (1 - torch.exp(-2 * thetas_cumsum * self.dt)))

 以上完成score计算

  • 17
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值