The UNet in IDDPM (full code at the end of the post)

Network modules

First, a note: I work on a super-resolution task, so I don't need the code that handles conditioning.

While studying IDDPM and DDPM, I had two questions about the UNet:

        1. How is the timestep t used in the model's prediction?

        2. How is the attention mechanism woven in?

Let's answer question 1 first. It's actually simple: just as in NLP, an Embedding is used to inject t into the model's prediction.

class TimeEmbedding(nn.Module):
    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)  # [d_model/2]: the frequencies 1/10000^(2i/d_model)
        pos = torch.arange(T).float()  # [T]
        # Outer product: [T,1] * [1,d_model/2] -> [T, d_model/2]
        emb = pos.unsqueeze(1) * emb.unsqueeze(0)
        assert list(emb.shape) == [T, d_model // 2]
        # stack creates a new trailing dimension to interleave sin and cos;
        # cat cannot do this, it only concatenates along existing dimensions
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)
        self.timembedding = nn.Sequential(
            # Everything above only builds this pretrained (and frozen) weight
            # table; roughly speaking, this is nn.Embedding(T, d_model)
            nn.Embedding.from_pretrained(emb),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        # Input [B] -> output [B, dim]
        emb = self.timembedding(t - 1)  # t is assumed to lie in [1, T]
        return emb

This code looks long and intimidating, but it can be simplified to:

class TimeEmbedding(nn.Module):
    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        self.timembedding = nn.Sequential(
            nn.Embedding(T, d_model),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        # Input [B] -> output [B, dim]
        emb = self.timembedding(t - 1)
        return emb

The big block at the front of the original code exists only to provide a pretrained embedding table: the classic sinusoidal positional encoding, sin(t / 10000^(2i/d_model)) and cos(t / 10000^(2i/d_model)) for each frequency index i. (One real difference: nn.Embedding.from_pretrained freezes the table by default, while the simplified version learns a table from random initialization.) Inside TimeEmbedding, nn.Embedding first embeds the timestep t of shape [B] into [B, d_model], and the fully connected layers then map it to [B, dim].
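As a quick sanity check, here is a minimal sketch of calling TimeEmbedding (it assumes the Swish class from the full code at the end; the sizes are just examples):

import torch

time_emb = TimeEmbedding(T=1000, d_model=128, dim=512)
t = torch.randint(1, 1001, (16,))  # a batch of 16 timesteps in [1, T]
out = time_emb(t)                  # forward subtracts 1, so indices land in [0, T-1]
print(out.shape)                   # torch.Size([16, 512])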

Now for question 2.

Simply put, the original code first implements a ResBlock. Each ResBlock contains two convolutions: the first changes the number of channels, the second learns the content of the feature map. I drew it as a diagram so it's clear at a glance. The original's GroupNorm can be replaced directly with BatchNorm; the two trade wins depending on the setting, so I won't go into that. A ResBlock takes two inputs, the feature map and the embedded time encoding. block1 acts on the feature map only; its output is then fused with the broadcast time encoding, and the sum passes through block2. Each block contains one convolution. At the end of the ResBlock you can choose whether that block gets an attention mechanism; if so, a multi-head self-attention (MSA) layer is appended, and MSA does not change the size of any dimension.

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.BatchNorm2d(in_ch),  # the original code uses GroupNorm here, normalizing over groups of channels
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            # projects the time encoding temb from dimension tdim to out_ch
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.BatchNorm2d(out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = Attention(out_ch)
        else:
            self.attn = nn.Identity()

    def forward(self, x, temb):
        h = self.block1(x)
        h += self.temb_proj(temb)[:, :, None, None]  # broadcast-add over the spatial dims
        h = self.block2(h)

        h = h + self.shortcut(x)
        h = self.attn(h)
        return h

The code is very simple, so there isn't much to explain; still, the shape check below may help.
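A minimal shape-check sketch (hypothetical sizes, assuming the classes defined above):

import torch

res = ResBlock(in_ch=64, out_ch=128, tdim=512, dropout=0.1, attn=False)
x = torch.randn(4, 64, 32, 32)  # [B, in_ch, H, W]
temb = torch.randn(4, 512)      # [B, tdim], e.g. the output of TimeEmbedding
h = res(x, temb)
print(h.shape)                  # torch.Size([4, 128, 32, 32]): channels change, spatial size does not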

The UNet network

This is a simplified UNet that I designed based on the original code.

First, the differences from the original: my code is poorly encapsulated and not decoupled, so getting a different model requires editing the code, though the edits are easy. The original can build different network structures from its hyperparameters; with the published hyperparameters, each resolution level uses two ResBlocks and attention appears in more places.

As for my modified code: my hardware is weak, so I had to shrink the model. I use attention only where the linear weight that IDDPM needs is predicted, i.e. at the middle block. Predicting that weight is essentially a binary-classification-style head; if you are doing plain DDPM, you can simply remove it. I believe the structure diagram alone makes this clear. The code follows; you can adjust the network structure inside UNet to match your hardware and needs.

class UNet(nn.Module):
    def __init__(self, T=1000, in_ch=3, out_ch=3, dropout=0.1, tdim=512):
        # tdim: embedding dimension of the time encoding
        super(UNet, self).__init__()
        # time embedding
        self.time_embedding = TimeEmbedding(T, 128, tdim)
        # encoder
        self.res1 = ResBlock(in_ch, 64, tdim, dropout)
        self.pool1 = nn.MaxPool2d(2)
        self.res2 = ResBlock(64, 128, tdim, dropout)
        self.pool2 = nn.MaxPool2d(2)
        self.res3 = ResBlock(128, 256, tdim, dropout)
        self.pool3 = nn.MaxPool2d(2)
        self.res4 = ResBlock(256, 512, tdim, dropout)
        self.pool4 = nn.MaxPool2d(2)
        # middle block, with self-attention
        self.midblock = ResBlock(512, 1024, tdim, dropout, attn=True)
        # predict the linear weight from the middle feature map
        self.var_weight = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1024, 512),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )
        # decoder
        self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.res5 = ResBlock(1024, 512, tdim, dropout)
        self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.res6 = ResBlock(512, 256, tdim, dropout)
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.res7 = ResBlock(256, 128, tdim, dropout)
        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.res8 = ResBlock(128, 64, tdim, dropout)
        # output layer
        self.out = nn.Conv2d(64, out_ch, 1)

    def forward(self, x, t):
        temb = self.time_embedding(t)
        # encoder
        x1 = self.res1(x, temb)
        x2 = self.res2(self.pool1(x1), temb)
        x3 = self.res3(self.pool2(x2), temb)
        x4 = self.res4(self.pool3(x3), temb)
        # middle block, with self-attention
        out = self.midblock(self.pool4(x4), temb)
        var_weight = self.var_weight(out)  # the predicted linear weight
        # decoder
        out = self.up1(out)
        out = torch.cat([out, x4], dim=1)
        out = self.res5(out, temb)

        out = self.up2(out)
        out = torch.cat([out, x3], dim=1)
        out = self.res6(out, temb)

        out = self.up3(out)
        out = torch.cat([out, x2], dim=1)
        out = self.res7(out, temb)

        out = self.up4(out)
        out = torch.cat([out, x1], dim=1)
        out = self.res8(out, temb)
        out = self.out(out)
        return out, var_weight.squeeze(1)  # return the predicted noise and the predicted linear weight
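A minimal usage sketch (hypothetical sizes; the input height and width must be divisible by 16 so the four poolings and the skip connections line up). The trailing comment shows, as an assumption about how a training loop would use it rather than code from this file, the IDDPM-style interpolation of the log-variance from the predicted weight:

import torch

model = UNet(T=1000, in_ch=3, out_ch=3)
x = torch.randn(2, 3, 64, 64)
t = torch.randint(1, 1001, (2,))
eps_pred, v = model(x, t)
print(eps_pred.shape, v.shape)  # torch.Size([2, 3, 64, 64]) torch.Size([2])

# IDDPM interpolates the log-variance with the predicted weight v, roughly:
#   log_var = v * log(beta_t) + (1 - v) * log(beta_tilde_t)
# where beta_t and beta_tilde_t come from precomputed noise schedules.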

Full code of Unet.py

import torch.nn as nn
import torch
import torch.nn.functional as F
import math


class Swish(nn.Module):
    '''The Swish activation function: x * sigmoid(x)'''
    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)  # [d_model/2]: the frequencies 1/10000^(2i/d_model)
        pos = torch.arange(T).float()  # [T]
        # Outer product: [T,1] * [1,d_model/2] -> [T, d_model/2]
        emb = pos.unsqueeze(1) * emb.unsqueeze(0)
        assert list(emb.shape) == [T, d_model // 2]
        # stack creates a new trailing dimension to interleave sin and cos;
        # cat cannot do this, it only concatenates along existing dimensions
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)
        self.timembedding = nn.Sequential(
            # Everything above only builds this pretrained (and frozen) weight
            # table; roughly speaking, this is nn.Embedding(T, d_model)
            nn.Embedding.from_pretrained(emb),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        # Input [B] -> output [B, dim]
        emb = self.timembedding(t - 1)  # t is assumed to lie in [1, T]
        return emb

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.BatchNorm2d(in_ch),  # the original code uses GroupNorm here, normalizing over groups of channels
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            # projects the time encoding temb from dimension tdim to out_ch
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.BatchNorm2d(out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = Attention(out_ch)
        else:
            self.attn = nn.Identity()

    def forward(self, x, temb):
        h = self.block1(x)
        h += self.temb_proj(temb)[:, :, None, None]  # broadcast-add over the spatial dims
        h = self.block2(h)

        h = h + self.shortcut(x)
        h = self.attn(h)
        return h


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads  # e.g. 1024 // 8 = 128
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)  # q, k, v from a single linear layer: dim -> 3*dim
        self.proj = nn.Linear(dim, dim)  # output projection

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W
        x = x.reshape(B, C, -1)  # [B,C,N]
        x = torch.einsum('xyz->xzy', x)  # permute to [B,N,C]
        qkv = self.qkv(x)  # [B,N,3*C]
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)  # [B,N,3,heads,head_dim]
        qkv = torch.einsum('abcde->cadbe', qkv)  # [3,B,heads,N,head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]  # each [B,heads,N,head_dim]
        attn = (q @ torch.einsum('abcd->abdc', k)) * self.scale  # [B,heads,N,N]
        # softmax over each row of the attention matrix
        attn = attn.softmax(dim=-1)
        x = (attn @ v)  # [B,heads,N,head_dim]
        x = torch.einsum('abcd->acbd', x)  # [B,N,heads,head_dim]
        x = x.reshape(B, N, C)  # [B,N,C]
        x = self.proj(x)
        x = x.reshape(B, H, W, C)  # [B,H,W,C]
        x = torch.einsum('abcd->adbc', x)  # back to [B,C,H,W]
        return x


class UNet(nn.Module):
    def __init__(self, T=1000, in_ch=3, out_ch=3, dropout=0.1, tdim=512):
        # tdim: embedding dimension of the time encoding
        super(UNet, self).__init__()
        # time embedding
        self.time_embedding = TimeEmbedding(T, 128, tdim)
        # encoder
        self.res1 = ResBlock(in_ch, 64, tdim, dropout)
        self.pool1 = nn.MaxPool2d(2)
        self.res2 = ResBlock(64, 128, tdim, dropout)
        self.pool2 = nn.MaxPool2d(2)
        self.res3 = ResBlock(128, 256, tdim, dropout)
        self.pool3 = nn.MaxPool2d(2)
        self.res4 = ResBlock(256, 512, tdim, dropout)
        self.pool4 = nn.MaxPool2d(2)
        # middle block, with self-attention
        self.midblock = ResBlock(512, 1024, tdim, dropout, attn=True)
        # predict the linear weight from the middle feature map
        self.var_weight = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1024, 512),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )
        # decoder
        self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.res5 = ResBlock(1024, 512, tdim, dropout)
        self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.res6 = ResBlock(512, 256, tdim, dropout)
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.res7 = ResBlock(256, 128, tdim, dropout)
        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.res8 = ResBlock(128, 64, tdim, dropout)
        # output layer
        self.out = nn.Conv2d(64, out_ch, 1)

    def forward(self, x, t):
        temb = self.time_embedding(t)
        # encoder
        x1 = self.res1(x, temb)
        x2 = self.res2(self.pool1(x1), temb)
        x3 = self.res3(self.pool2(x2), temb)
        x4 = self.res4(self.pool3(x3), temb)
        # middle block, with self-attention
        out = self.midblock(self.pool4(x4), temb)
        var_weight = self.var_weight(out)  # the predicted linear weight
        # decoder
        out = self.up1(out)
        out = torch.cat([out, x4], dim=1)
        out = self.res5(out, temb)

        out = self.up2(out)
        out = torch.cat([out, x3], dim=1)
        out = self.res6(out, temb)

        out = self.up3(out)
        out = torch.cat([out, x2], dim=1)
        out = self.res7(out, temb)

        out = self.up4(out)
        out = torch.cat([out, x1], dim=1)
        out = self.res8(out, temb)
        out = self.out(out)
        return out, var_weight.squeeze(1)  # return the predicted noise and the predicted linear weight
