https://blog.csdn.net/sjyttkl/article/details/105052669
就是先用 1×1 卷积处理一下 shortcut，然后用一个可学习参数（alpha）把它和主分支结合起来，我之前用过。
# Projection shortcut: when the block changes spatial size (stride != 1) or
# channel count (in_planes != self.expansion*planes), a bias-free 1x1 conv
# with the same stride maps the input to self.expansion*planes channels.
# NOTE(review): the enclosing method (presumably a residual block's __init__)
# is not part of this note, and indentation was lost in the paste -- confirm
# whether the ReZeroShortcut line below belongs inside the `if` or after it.
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
)
# Gate module for the shortcut; its learnable 1-element alpha starts at 0.0.
self.rezero_shortcut = ReZeroShortcut(alpha=0.0)
class ReZeroShortcut(nn.Module):
def __init__(self, alpha=0.0):
super(ReZeroShortcut, self).__init__()
self.alpha = Parameter(torch.ones(1) * alpha)
self.tanh = nn.Tanh()
def forward(self, shortcut, x):
return sho