Implementing a ResNet-based UNet in PyTorch

This post walks through a PyTorch implementation of a ResNet-UNet model, including the definitions of the double_conv and root_block modules and the structure of the ResNetUNet class. The model consists of downsampling and upsampling convolutional stages, with skip connections used to fuse features.

I want to implement a ResNet-based UNet for segmentation (without pretraining). I have found this implementation in Keras, but my project is written in PyTorch and I am not sure whether I have done it correctly.

Here is my PyTorch implementation (I am not sure it is correct); any advice would be greatly appreciated.

import torch
import torch.nn as nn

def double_conv(in_channels, out_channels):
    # Two 3x3 conv -> BN -> ReLU stages; padding=1 preserves spatial size.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )
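As a quick sanity check (a minimal sketch; the tensor sizes are arbitrary), double_conv only changes the channel count: the 3x3 convolutions with padding=1 keep the spatial resolution unchanged.

```python
import torch
import torch.nn as nn

def double_conv(in_channels, out_channels):
    # Two 3x3 convs with padding=1: spatial size is preserved.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )

block = double_conv(3, 64)
x = torch.randn(2, 3, 32, 32)  # arbitrary batch of RGB images
y = block(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```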

def root_block(in_channels, out_channels):
    # Same as double_conv but without the final ReLU: the residual
    # addition and activation happen later, in forward().
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
    )
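The residual pattern applied around root_block can be sketched in isolation (a minimal sketch with arbitrary sizes): because root_block ends in BatchNorm2d rather than ReLU, the identity is added first and the activation is applied afterwards, as in a standard ResNet basic block.

```python
import torch
import torch.nn as nn

def root_block(in_channels, out_channels):
    # Ends in BatchNorm2d, not ReLU: the activation comes after the skip add.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
    )

block = root_block(64, 64)
x = torch.randn(2, 64, 16, 16)
y = torch.relu(block(x) + x)  # residual add, then ReLU
print(y.shape)  # torch.Size([2, 64, 16, 16])
```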

# Define the ResNet-UNet architecture
class ResNetUNet(nn.Module):

    def __init__(self, n_class):
        super().__init__()

        # Encoder: channel width doubles at each stage (64 -> 128 -> 256 -> 512).
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down11 = root_block(64, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down21 = root_block(128, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down31 = root_block(256, 256)
        self.dconv_down4 = double_conv(256, 512)
        self.dconv_down41 = root_block(512, 512)

        self.maxpool = nn.MaxPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # Decoder: input channels = skip-connection channels + upsampled channels.
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up31 = root_block(256, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up21 = root_block(128, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)
        self.dconv_up11 = root_block(64, 64)

        # Final 1x1 conv maps to the number of output classes.
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, x):
        # Encoder: each stage is double_conv -> root_block with a residual add.
        conv1 = self.dconv_down1(x)
        x = self.dconv_down11(conv1)
        x += conv1
        x = self.relu(x)
        x = self.maxpool(x)

        conv2 = self.dconv_down2(x)
        x = self.dconv_down21(conv2)
        x += conv2
        x = self.relu(x)
        x = self.maxpool(x)

        conv3 = self.dconv_down3(x)
        x = self.dconv_down31(conv3)
        x += conv3
        x = self.relu(x)
        x = self.maxpool(x)

        conv4 = self.dconv_down4(x)
        x = self.dconv_down41(conv4)
        x += conv4
        x = self.relu(x)

        # Decoder: bilinear upsample, concatenate the encoder skip, then convolve.
        deconv3 = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        deconv3 = torch.cat([deconv3, conv3], dim=1)
        uconv3 = self.dconv_up3(deconv3)
        x = self.dconv_up31(uconv3)
        x += uconv3
        x = self.relu(x)

        deconv2 = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        deconv2 = torch.cat([deconv2, conv2], dim=1)
        uconv2 = self.dconv_up2(deconv2)
        x = self.dconv_up21(uconv2)
        x += uconv2
        x = self.relu(x)

        deconv1 = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        deconv1 = torch.cat([deconv1, conv1], dim=1)
        uconv1 = self.dconv_up1(deconv1)
        x = self.dconv_up11(uconv1)
        x += uconv1
        x = self.relu(x)

        # 1x1 conv produces the per-class score map at the input resolution.
        out = self.conv_last(x)

        return out
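The decoder step that the forward pass repeats (upsample by 2, concatenate the encoder skip, convolve) can be checked in isolation. A minimal sketch with arbitrary sizes, using the channel counts of the third decoder stage (512 upsampled + 256 skip -> 256):

```python
import torch
import torch.nn as nn

# Decoder tensor at 8x8 with 512 channels; matching encoder skip at 16x16 with 256 channels.
x = torch.randn(1, 512, 8, 8)
skip = torch.randn(1, 256, 16, 16)

# Bilinear upsampling doubles the spatial size, keeping channels unchanged.
up = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
merged = torch.cat([up, skip], dim=1)  # channel counts add: 512 + 256
print(up.shape)      # torch.Size([1, 512, 16, 16])
print(merged.shape)  # torch.Size([1, 768, 16, 16])

# Same in/out channels as self.dconv_up3 in the model above.
conv = nn.Conv2d(256 + 512, 256, 3, padding=1)
print(conv(merged).shape)  # torch.Size([1, 256, 16, 16])
```

This also explains the channel arithmetic in __init__: each dconv_up layer takes the sum of the skip channels and the upsampled channels as its input width.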

The Keras-based implementation I am referring to is as follows:

[Keras code block not preserved in the source.]
