PyTorch：U-Net + 残差块（res_block）用于图像序列预测

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


def transfer(x):
    """Fold the two leading axes of a 5-D tensor into a single batch axis.

    The original code names the axes (seq_number, batch_size, input_channel,
    height, width); the result has shape
    (seq_number * batch_size, input_channel, height, width), i.e. a 4-D tensor
    that 2-D convolution layers accept.

    Raises:
        ValueError: if ``x`` does not have exactly five dimensions (raised by
            the tuple unpacking of its shape).
    """
    _, _, channels, rows, cols = x.shape
    return x.reshape(-1, channels, rows, cols)


class basic_block(nn.Module):
    """Residual block made of two 3x3 convolutions and a 1x1 shortcut.

    Both paths use stride 1 with ``padding='same'``, so height and width are
    preserved; the channel count changes from ``ch_in`` to ``ch_out``.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # Main path: conv -> BN -> LeakyReLU -> conv (no activation at the end).
        # The LeakyReLU uses PyTorch's default negative slope (0.01).
        main_layers = [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding='same'),
            nn.BatchNorm2d(ch_out),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding='same'),
        ]
        self.conv = nn.Sequential(*main_layers)
        # Shortcut: a 1x1 projection so its channel count matches the main path.
        projection = nn.Conv2d(ch_in, ch_out, kernel_size=1, padding='same')
        self.residual = nn.Sequential(projection)

    def forward(self, x):
        """Return ``conv(x) + residual(x)``."""
        main = self.conv(x)
        shortcut = self.residual(x)
        return main + shortcut


class down_block(nn.Module):
    """Pre-activation residual block that halves the spatial resolution.

    ``forward`` returns a pair ``(out, features)``:

    * ``features`` is the main path
      BN -> LeakyReLU(0.1) -> 2x2 max-pool -> BN -> LeakyReLU(0.1) -> 3x3 conv;
    * ``out`` is ``features`` plus a stride-2 1x1-conv shortcut of the input.

    In this file, ``res_U_Net`` passes ``out`` to the next encoder stage and
    keeps ``features`` as the skip connection for the decoder.

    NOTE(review): MaxPool2d(2, stride=2) floors an odd height/width while the
    stride-2 1x1 conv rounds it up, so the two paths disagree in size and the
    addition fails for odd spatial sizes.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        main_layers = [
            nn.BatchNorm2d(ch_in),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(2, stride=2),
            nn.BatchNorm2d(ch_in),
            nn.LeakyReLU(0.1),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding='same'),
        ]
        self.conv = nn.Sequential(*main_layers)
        # Strided 1x1 projection: downsamples and matches channels in one step.
        projection = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=2)
        self.residual = nn.Sequential(projection)

    def forward(self, x):
        """Return ``(features + shortcut, features)``."""
        features = self.conv(x)
        shortcut = self.residual(x)
        return features + shortcut, features


class up_block(nn.Module):
    """Residual block that merges two feature maps and doubles their resolution.

    The two inputs are concatenated along dim 1, so their channel counts must
    add up to ``ch_in``. Both paths then upsample by 2:

    * main path: bilinear upsample -> BN -> LeakyReLU(0.1) -> 3x3 conv
      -> BN -> LeakyReLU(0.1) -> 3x3 conv;
    * shortcut: 2x2 transposed convolution with stride 2.

    Args:
        ch_in: channels of the concatenated input.
        ch_out: number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        main_layers = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.BatchNorm2d(ch_in),
            nn.LeakyReLU(0.1),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, padding='same'),
            nn.BatchNorm2d(ch_out),
            nn.LeakyReLU(0.1),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, padding='same'),
        ]
        self.conv = nn.Sequential(*main_layers)
        # Learned 2x upsampling on the shortcut so both paths end at the same size.
        projection = nn.ConvTranspose2d(ch_in, ch_out, kernel_size=2, stride=2)
        self.residual = nn.Sequential(projection)

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` on dim 1, then return ``conv + residual``."""
        merged = torch.cat([x, y], dim=1)
        return self.conv(merged) + self.residual(merged)


class res_U_Net(nn.Module):
    """U-Net built from residual blocks: 7 downsampling and 7 upsampling stages.

    Input shape (N, img_ch, H, W) maps to output shape (N, output_ch, H, W).
    Each of the seven ``down_block`` stages halves H and W, and each stage needs
    an even size (see the shortcut/max-pool rounding in ``down_block``). So H and
    W must be divisible by 2**7 = 128.

    Presumably ``img_ch`` frames of a sequence are stacked on the channel axis
    for sequence prediction (per the article title) — TODO confirm with callers.

    NOTE: the ``x<k>`` suffixes in the attribute names do not match the actual
    downsampling factors (e.g. ``downx4`` outputs H/8). They are kept because
    they are the state_dict keys.

    Args:
        img_ch: number of input channels.
        output_ch: number of output channels.
    """

    def __init__(self, img_ch=5, output_ch=5):
        super().__init__()
        # Registration order below fixes state_dict / parameters() order.
        self.basicx1 = basic_block(img_ch, 32)        # 32 ch,   H
        # Encoder: each stage halves resolution and yields a skip tensor.
        self.downx1 = down_block(32, 32)              # 32 ch,   H/2
        self.downx2 = down_block(32, 64)              # 64 ch,   H/4
        self.downx4 = down_block(64, 128)             # 128 ch,  H/8
        self.downx16 = down_block(128, 256)           # 256 ch,  H/16
        self.downx32 = down_block(256, 512)           # 512 ch,  H/32
        self.downx64 = down_block(512, 512)           # 512 ch,  H/64
        self.downx128 = down_block(512, 1024)         # 1024 ch, H/128
        self.centerx128 = basic_block(1024, 1024)     # 1024 ch, H/128
        # Decoder: ch_in = previous decoder channels + matching skip channels.
        self.upx64 = up_block(2048, 1024)             # 1024 + 1024 -> 1024 ch, H/64
        self.upx32 = up_block(1536, 512)              # 1024 + 512  -> 512 ch,  H/32
        self.upx16 = up_block(1024, 512)              # 512 + 512   -> 512 ch,  H/16
        self.upx8 = up_block(768, 256)                # 512 + 256   -> 256 ch,  H/8
        self.upx4 = up_block(384, 128)                # 256 + 128   -> 128 ch,  H/4
        self.upx2 = up_block(192, 64)                 # 128 + 64    -> 64 ch,   H/2
        self.upx1 = up_block(96, 32)                  # 64 + 32     -> 32 ch,   H
        self.conv = nn.Conv2d(32, output_ch, kernel_size=3, padding='same')

    def forward(self, x):
        """Run the encoder, bottleneck and decoder, then a final 3x3 conv."""
        encoder = (self.downx1, self.downx2, self.downx4, self.downx16,
                   self.downx32, self.downx64, self.downx128)
        decoder = (self.upx64, self.upx32, self.upx16, self.upx8,
                   self.upx4, self.upx2, self.upx1)

        h = self.basicx1(x)
        skips = []
        for stage in encoder:
            h, skip = stage(h)
            skips.append(skip)

        h = self.centerx128(h)
        # The deepest skip pairs with the first decoder stage, and so on upward.
        for stage, skip in zip(decoder, reversed(skips)):
            h = stage(h, skip)
        return self.conv(h)


if __name__ == "__main__":
    # Architecture printout when the file is run as a script.
    # torchsummary is a third-party package, so it is imported only here.
    # Importing this module for the model classes then neither requires it
    # nor prints a summary as an import-time side effect.
    import torchsummary

    # The 256x256 input satisfies res_U_Net's divisible-by-128 requirement.
    torchsummary.summary(res_U_Net(img_ch=5, output_ch=5), input_size=[(5, 256, 256)], batch_size=2,
                         device="cpu")

U-Net + 残差网络，可用于分类或预测，由 TensorFlow 版本的代码改写而来。

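另外还有一点疑惑：TensorFlow 中的 activation='linear' 是否等同于 LeakyReLU(negative_slope=1)？（补充：LeakyReLU(x) = max(0, x) + negative_slope·min(0, x)，当 negative_slope=1 时结果恒等于 x，因此与 activation='linear'（恒等映射，即不做激活）等价；在 PyTorch 中直接省略激活层或使用 nn.Identity() 即可。）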

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值