作者在视频中跑的是mobilnet模型,这里我们尝试跑一下res50+fpn的模型
1. create_model
这个类是定义模型的部分。
这里需要注意的是
backbone = resnet50_fpn_backbone()
会自动的冻结部分底层权重
代码如下(示例):
def create\_model(num_classes):
backbone = resnet50_fpn_backbone()
# 训练自己数据集时不要修改这里的91,修改的是传入的num\_classes参数
model = FasterRCNN(backbone=backbone, num_classes=91)
# 载入预训练模型权重
# https://download.pytorch.org/models/fasterrcnn\_resnet50\_fpn\_coco-258fb6c6.pth
weights_dict = torch.load("./backbone/fasterrcnn\_resnet50\_fpn\_coco.pth")
missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
if len(missing_keys) != 0 or len(unexpected_keys) != 0:
print("missing\_keys: ", missing_keys)
print("unexpected\_keys: ", unexpected_keys)
# get number of input features for the