nanodiffusion代码逐行理解之time embedding


time embedding本质上就是把时间步t转换为指定维度的嵌入向量,这个向量由时间步
,向量维度,周期等参数决定,可以简单理解为根据时间步t在正弦余弦函数采样得到的向量。

一、time embedding调用过程

传输时间time,得到嵌入向量

    def forward(self, x, time=None, y=None):
        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but tim is not passed")
            time_emb = self.time_mlp(time)
        else:
            time_emb = None

二、time embedding定义过程

PositionalEmbedding相当于创建的一个查询表,根据time,返回向量,再经过nn.Linear()->nn.SiLU()->nn.Linear()得到最终的嵌入向量

class UNet(nn.Module):
    def __init__(
        self,
        img_channels,
        base_channels,
        channel_mults=(1, 2, 4, 8),
        num_res_blocks=2,
        time_emb_dim=None,
        time_emb_scale=1.0,
        num_classes=None,
        activation=F.relu,
        dropout=0.1,
        attention_resolutions=(),
        norm="gn",
        num_groups=32,
        initial_pad=0,
    ):
        super().__init__()
        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None

三、PositionalEmbedding定义过程

class PositionalEmbedding(nn.Module):
    __doc__ = r"""Computes a positional embedding of timesteps.

    Input:
        x: tensor of shape (N)
    Output:
        tensor of shape (N, dim)
    Args:
        dim (int): embedding dimension
        scale (float): linear scale to be applied to timesteps. Default: 1.0
    """

    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

下面是transformer中位置嵌入的一个定义。
在这里插入图片描述
在这里插入图片描述
接下来是这段代码对应的位置嵌入的一个定义。
上述代码的emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
对应下面这个式子torch.exp(torch.arange(half_dim) / half_dim)*1/10000
也就是
在这里插入图片描述
在这里插入图片描述

  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
这段代码是用于实现Vision Transformer框架的一部分功能,具体逐行解析如下: 1. `conv_output = F.conv2d(image, kernel, stride=stride)`: 这一行代码使用PyTorch中的卷积函数`F.conv2d`来对输入图像进行卷积操作。 2. `bs, oc, oh, ow = conv_output.shape`: 这一行代码通过`conv_output.shape`获取卷积输出张量的形状信息,其中`bs`表示批次大小,`oc`表示输出通道数,`oh`和`ow`分别表示输出张量的高度和宽度。 3. `patch_embedding = conv_output.reshape((bs, oc, oh*ow))`: 这一行代码通过`reshape`函数将卷积输出张量进行形状变换,将其转换为形状为`(bs, oc, oh*ow)`的张量。 4. `patch_embedding = patch_embedding.transpose(-1, -2)`: 这一行代码使用`transpose`函数交换张量的最后两个维度,将形状为`(bs, oh*ow, oc)`的张量转换为`(bs, oc, oh*ow)`的张量。 5. `weight = weight.transpose(0, 1)`: 这一行代码将权重张量进行转置操作,交换第0维和第1维的位置。 6. `kernel = weight.reshape((-1, ic, patch_size, patch_size))`: 这一行代码通过`reshape`函数将权重张量进行形状变换,将其转换为形状为`(outchannel*inchannel, ic, patch_size, patch_size)`的张量。 7. `patch_embedding_conv = image2emb_conv(image, kernel, patch_size)`: 这一行代码调用了`image2emb_conv`函数,并传入了图像、权重张量和补丁大小作为参数。 8. `print(patch_embedding_conv.shape)`: 这一行代码打印了`patch_embedding_conv`的形状信息。 以上是对Vision Transformer代码逐行解析。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值