# Model structure (模型结构)
def downsample_basic_block(x, planes, stride):
    """Subsample ``x`` and zero-pad its channel dimension up to ``planes``.

    The input goes through ``F.avg_pool3d`` with a window of size 1 at the
    given stride, then all-zero channels are appended along dim 1 until the
    result has ``planes`` channels.

    Args:
        x: 5-D tensor; dim 1 is the dimension that gets padded.
        planes: Size of dim 1 in the returned tensor. Must be at least
            ``x.size(1)``; a smaller value makes the padding size negative
            and tensor allocation raises.
        stride: Stride forwarded to ``F.avg_pool3d``.

    Returns:
        Tensor whose dim 1 has size ``planes``, with the same dtype and
        device as the pooled input. It is detached from the autograd graph.
    """
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)

    # new_zeros inherits dtype and device from `out`, so the padding always
    # matches the pooled tensor (half/double inputs, any CUDA device index)
    # instead of only the exact torch.cuda.FloatTensor case.
    zero_pads = out.new_zeros(
        (out.size(0), planes - out.size(1), out.size(2), out.size(3),
         out.size(4)))

    # NOTE(review): the previous code concatenated `out.data`, so no gradient
    # flowed through this path. detach() keeps that behaviour unchanged;
    # confirm whether cutting the gradient here is actually intended.
    return torch.cat([out.detach(), zero_pads], dim=1)
class ResNeXtBottleneck(nn.Module):
expansion = 2
def __init__(self, inplanes, planes, cardinality, stride=1,
downsample=None,conv3d_bias=True):
super(ResNeXtBottleneck, self).__init__()
mid_planes = cardinality * int(planes / 32)
self.conv1 = nn.Conv3d(inplanes, mid_planes, kernel_size=1, bias=conv3d_bias)
self.bn1 = nn.BatchNorm3d(mid_planes)
self.conv2 = nn.Conv3d(
mid_planes,
mid_planes,
kernel_size=3,
stride=stride,
padding=1,
groups=cardinality,
bias=conv3d_bias)
self.bn2 = nn.BatchNorm3d(mid_planes)
self.conv3 = nn.Conv3d(
mid_planes, planes * self.expansion, kernel_size=1, bias=conv3d_bias)
self.bn3 = nn.BatchNorm3d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out =