1 残差网络结构
2 实现代码
class ResidualBlock(torch.nn.Module):
    """Residual block: two 3x3 convolutions plus an identity skip connection.

    Both convolutions map ``channels`` -> ``channels`` with ``kernel_size=3``
    and ``padding=1`` (default stride 1), so the spatial size is preserved and
    the convolution branch can be added element-wise to the block's input.

    Args:
        channels: Number of input channels; the output has the same number.
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.channels = channels
        self.conv1 = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Return ``relu(x + conv2(relu(conv1(x))))``.

        Args:
            x: Input tensor with ``channels`` channels, e.g. (N, channels, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Do not rebind ``x``: the skip connection must add the block's
        # original *input*, not the intermediate activation of conv1.
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        return F.relu(x + y)
class Net(torch.nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1=torch.nn.Conv2d(1,16,kernel_size=5)
self.conv2=torch.nn.Conv2d(16,32,kernel_size=5)