class Backbone(nn.Module):
    """Truncated ResNet-50 feature extractor.

    Keeps the first seven child modules of ``resnet50()`` and forces the
    first block of the last kept stage to use stride 1, so that stage no
    longer downsamples its input.

    Args:
        pretrain_path: Optional path to a ``state_dict`` checkpoint for the
            full ``resnet50`` network. When ``None`` the network keeps
            whatever initialisation ``resnet50()`` provides.

    Attributes:
        out_channels: Fixed list ``[1024, 512, 512, 256, 256, 256]``.
            Presumably the channel counts of the feature maps consumed by
            downstream detection heads -- TODO confirm against the caller.
        feature_extractor: ``nn.Sequential`` of the retained modules.
    """

    def __init__(self, pretrain_path=None):
        super().__init__()
        net = resnet50()
        self.out_channels = [1024, 512, 512, 256, 256, 256]

        if pretrain_path is not None:
            # Deserialize onto the CPU so checkpoints saved from a GPU also
            # load on CPU-only hosts; load_state_dict copies the values into
            # the parameters of ``net`` wherever they already live.
            # NOTE(review): torch.load unpickles the file -- only load
            # checkpoints from trusted sources.
            state_dict = torch.load(pretrain_path, map_location="cpu")
            net.load_state_dict(state_dict)

        # Keep only the first seven children of the network (assumed to be
        # the stem plus the stages up to conv4_x -- verify against the
        # resnet50 definition).
        self.feature_extractor = nn.Sequential(*list(net.children())[:7])

        conv4_block1 = self.feature_extractor[-1][0]

        # Change the stride of conv4_block1 from 2 to 1. Both conv1 and
        # conv2 are reset, so the override applies whichever of the two
        # carries the stride, and the shortcut conv is kept in step.
        conv4_block1.conv1.stride = (1, 1)
        conv4_block1.conv2.stride = (1, 1)
        conv4_block1.downsample[0].stride = (1, 1)

    def forward(self, x):
        """Run ``x`` through the truncated network and return the features."""
        return self.feature_extractor(x)
07-22