1. 基本ResNet单元:
import torch
from torch import nn
from torch.nn import functional as F
class Resnet(nn.Module):
    """Basic residual unit: two 3x3 conv + batch-norm stages and a shortcut.

    This block only constructs the layers. Both 3x3 convolutions use
    stride 1 and padding 1, so they keep the spatial size of their input.

    Attributes:
        conv1, bn1: First 3x3 convolution (ch_in -> ch_out) and its batch norm.
        conv2, bn2: Second 3x3 convolution (ch_out -> ch_out) and its batch norm.
        extra: Shortcut branch. An empty ``nn.Sequential`` (identity) when
            ``ch_in == ch_out``; otherwise a 1x1 convolution followed by
            batch norm that maps ``ch_in`` channels to ``ch_out``.
    """

    def __init__(self, ch_in: int, ch_out: int) -> None:
        """Build the layers of the unit.

        Args:
            ch_in: Number of channels of the input feature map.
            ch_out: Number of channels produced by the unit.
        """
        super().__init__()

        # Main branch: conv -> bn, twice.
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Shortcut branch: an empty Sequential passes its input through as-is.
        self.extra = nn.Sequential()
        # If the input and output channel counts differ, the input has to be
        # projected to ch_out channels before the two branches can be added.
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out),
            )