修改预训练权重类别数
import os
import torch
import argparse
def init_args():
parser = argparse.ArgumentParser()
parser.add_argument("--org_path", type=str, help="the path of pretrained model")
parser.add_argument("--num_classes", type=int, default=26, help="number of classes")
return parser.parse_args()
def modify_cascade_rcnn(model_coco, num_classes):
model_coco["state_dict"]["bbox_head.0.fc_cls.we
mmdetection 修改预训练模型权重类别数
最新推荐文章于 2024-08-12 16:25:03 发布