实现片段
from torch import nn
import torch
class u_dowm(nn.Module):
    """Encoder (downsampling) stage: two 3x3 conv-BN-ReLU layers, then 2x2 max-pool.

    Presumably the contracting-path block of a U-Net-style network (inferred
    from the class name; confirm against the rest of the model).

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels produced by both convolutions.

    Attributes:
        convANDbn: ``nn.Sequential`` of
            [Conv2d, BatchNorm2d, ReLU, Conv2d, BatchNorm2d, ReLU].
            The attribute name and layer indices are part of the state_dict
            key layout, so they are kept unchanged.
        downsample: ``nn.Sequential`` holding a single ``MaxPool2d(2, 2)``.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        stages = []
        # The first conv maps in_channel -> out_channel; the second keeps out_channel.
        # Layers are created in the same order as a hand-written list, so
        # parameter initialisation draws from the RNG in the same sequence.
        for source_channels in (in_channel, out_channel):
            stages.extend(
                [
                    nn.Conv2d(
                        in_channels=source_channels,
                        out_channels=out_channel,
                        kernel_size=3,
                        stride=1,
                        padding=1,  # 3x3 kernel, stride 1, padding 1 keeps H and W unchanged
                    ),
                    nn.BatchNorm2d(out_channel),
                    nn.ReLU(),
                ]
            )
        self.convANDbn = nn.Sequential(*stages)
        self.downsample = nn.Sequential(nn.MaxPool2d(2, 2))

    def forward(self, x):
        """Run the stage on ``x`` (expected ``(N, in_channel, H, W)``).

        Returns:
            A tuple ``(features, pooled)``. ``features`` is the conv-BN-ReLU
            output at the input resolution, which the original author notes is
            concatenated with the upsampled features later. ``pooled`` is
            ``features`` after 2x2 max-pooling with stride 2.
        """
        features = self.convANDbn(x)
        pooled = self.downsample(features)
        return features, pooled
class up(nn.Module):
def