IDDPM Code Walkthrough: ResBlock and TimestepEmbedSequential

The difference between forward and _forward in ResBlock

# ResBlock fuses the timestep embedding into the image features in residual form,
# i.e. it injects the time information into the feature maps.
class ResBlock(TimestepBlock):
    # ResBlock inherits from TimestepBlock, so every ResBlock must be given the embedding;
    # attention, upsampling and downsampling layers do not take the embedding.
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        self.emb_layers = nn.Sequential(
            SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:  # channel counts match: the skip connection is just the identity
            self.skip_connection = nn.Identity()
        elif use_conv:  # channel counts differ: use a 3x3 convolution (spatial size unchanged) to change channels
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:  # otherwise, as in some papers, a 1x1 convolution does a pointwise channel projection
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):  # _forward is the private method that does the actual computation and returns its result to forward
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift  # scale-shift norm: ys = (1 + scale), yb = shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h  # residual connection (skip_connection is identity when channels match)

In this class, forward is the public interface: it passes the input through the module's submodules and returns the result, while _forward is the private method that performs the actual computation and hands its result back to forward. forward wraps _forward in checkpoint so that, when use_checkpoint is enabled, intermediate activations are not stored during the forward pass and are instead recomputed during backpropagation, trading extra compute for lower memory use. In short, forward is the external interface of ResBlock and _forward is its internal implementation.
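To make the pattern concrete, here is a minimal, self-contained sketch of the forward/_forward plus checkpointing idea. It is not the IDDPM code: it uses torch.utils.checkpoint.checkpoint instead of the repo's own checkpoint helper, and TinyTimestepBlock and its shapes are made up for illustration.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyTimestepBlock(nn.Module):
    def __init__(self, channels, emb_channels, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.emb_proj = nn.Linear(emb_channels, channels)

    def forward(self, x, emb):
        # public interface: optionally wrap the real computation in gradient checkpointing
        if self.use_checkpoint:
            # activations inside _forward are recomputed during backward instead of stored
            return checkpoint(self._forward, x, emb, use_reentrant=False)
        return self._forward(x, emb)

    def _forward(self, x, emb):
        # private method: the actual computation
        h = self.conv(x)
        # broadcast the embedding over the spatial dimensions, as ResBlock does
        h = h + self.emb_proj(emb)[..., None, None]
        return x + h


x = torch.randn(2, 8, 16, 16, requires_grad=True)
emb = torch.randn(2, 32)
out = TinyTimestepBlock(8, 32)(x, emb)
out.sum().backward()  # gradients flow through the checkpointed segment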

TimestepEmbedSequential

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """
    # emb: the timestep embedding, possibly mixed with a condition embedding
    def forward(self, x, emb):
        for layer in self:  
            if isinstance(layer, TimestepBlock):  # pass emb only to layers that are TimestepBlocks
                x = layer(x, emb)
            else:
                x = layer(x)
        return x

It is called like this:

TimestepEmbedSequential(
    conv_nd(dims, in_channels, model_channels, 3, padding=1)
)

How does the result of conv_nd get passed into TimestepEmbedSequential?

When TimestepEmbedSequential is called here, it is given a single nn.Module created by conv_nd(dims, in_channels, model_channels, 3, padding=1). conv_nd is a factory function that builds a convolution layer from its arguments: dims is an integer choosing the dimensionality of the convolution (1 for 1D, 2 for 2D, 3 for 3D), in_channels is the number of input channels, model_channels is the number of output channels, and 3 is the kernel size. The resulting convolution module is passed as an argument to TimestepEmbedSequential's constructor, which it inherits from nn.Sequential.
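As a rough illustration of what such a factory looks like (the real helper lives in the IDDPM repo's nn utilities; this sketch just dispatches on dims):

import torch.nn as nn

def conv_nd_sketch(dims, *args, **kwargs):
    # choose Conv1d / Conv2d / Conv3d based on the requested dimensionality
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    if dims == 2:
        return nn.Conv2d(*args, **kwargs)
    if dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

# conv_nd_sketch(2, 3, 128, 3, padding=1) behaves like nn.Conv2d(3, 128, kernel_size=3, padding=1)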

Inside TimestepEmbedSequential's forward method, that convolution layer is used as an ordinary child of the nn.Sequential, i.e. it was registered as one of the submodules iterated over by self. When forward(x, emb) is called, the input tensor x is passed through each child in turn; whenever a child is an instance of TimestepBlock, the timestep embedding emb is additionally passed to its forward method. In this example conv_nd returns a plain convolution, not a TimestepBlock, so emb is ignored for that layer and only x is passed to it.
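The dispatch can be reproduced with a small, self-contained sketch (the class names here are illustrative, not the repo's):

import torch
import torch.nn as nn


class TimestepBlockSketch(nn.Module):
    # mirrors the TimestepBlock interface: forward takes (x, emb)
    def forward(self, x, emb):
        return x + emb[..., None, None]  # toy fusion of the embedding


class EmbedSequentialSketch(nn.Sequential):
    def forward(self, x, emb):
        for layer in self:
            # only TimestepBlock-style children receive the embedding
            if isinstance(layer, TimestepBlockSketch):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x


block = EmbedSequentialSketch(
    nn.Conv2d(3, 8, 3, padding=1),  # plain layer: only x is passed
    TimestepBlockSketch(),          # timestep-aware layer: x and emb are passed
)
x = torch.randn(2, 3, 16, 16)
emb = torch.randn(2, 8)
print(block(x, emb).shape)  # torch.Size([2, 8, 16, 16])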
