Pytorch 目标检测通过fssd训练自己的数据集

Pytorch 目标检测通过fssd训练自己的数据集

参考文章
代码参考:fssd-pytorch实现代码

1、下载代码

2、整理自己的数据集

  1. 根据VOC的格式整理自己的数据集,xml文件放到Annotations文件夹下面。然后Main放自己的数据集划分,格式如下面的val.txt格式。
    在这里插入图片描述
    val.txt格式
  2. 更改一些数据集的代码
    因为只用到一个数据集,所以不需要将两个数据集都弄上,否则检查会报错。就在train.py那里进行更改即可。根据数据集命名方式进行更改即可。
    在这里插入图片描述
    在这里插入图片描述

3、选用VOC格式将COCO相关内容进行注释

在train.py中

在这里插入图片描述在这里插入图片描述
在这里插入图片描述在data.init_.py中
在这里插入图片描述

4、在data/voc0712.py中将检测类型改成自己要检测的类型

5、修改FSSD——VGG.py文件中的build_net的num_class修改为自己的种类(+背景)

6、因为pytorch版本问题,把VGG中的RELU(inplace)进行修改:

RELU(inplace=True)
RELU(inplace=False)

7、使用预训练模型进行训练

  1. 如果是只使用vgg16作为预训练模型,那么resume_net这里就不用管。
  2. 如果用别的预训练模型,如代码网站上的已经训练好的作为预训练模型,那么就得更改一些地方。

在这里插入图片描述
更改参数resume:
在这里插入图片描述

因为权重文件上有不能兼容模型的输入和输出的部分,所以需要定义交叉取值函数。
在这里插入图片描述

定义网络结构交叉取值函数

# 关键自定义函数

def intersect_dicts(da, db, exclude=()):
    """输入参数
    da (state_dict)			 加载权重的 state_dict
    db (state_dict) 	 	 加载模型的 state_dict
    exclude (list)           不想要的权重 keys()

    返回参数
    加载的部分权重 (state_dict)
    """
    '''
    print("exclude",exclude)
    for k, v in da.items():
        for x in exclude:
            if x in k:
                print('@ ',x ,k)
            if v.shape != db[k].shape:
                print('# ', x, k)
	'''

    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}

else:
    def xavier(param):
        init.xavier_uniform(param)


    def weights_init(m):
        for key in m.state_dict():
            if key.split('.')[-1] == 'weight':
                if 'conv' in key:
                    # init.kaiming_normal(m.state_dict()[key], mode='fan_out')
                    init.kaiming_normal_(m.state_dict()[key], mode='fan_out')
                if 'bn' in key:
                    m.state_dict()[key][...] = 1
            elif key.split('.')[-1] == 'bias':
                m.state_dict()[key][...] = 0

# 这里先对 不需要从权重载入的部分进行初始化。
    print('Initializing weights...')
    # initialize newly added layers' weights with kaiming_normal method
    net.extras.apply(weights_init)
    net.loc.apply(weights_init)
    net.conf.apply(weights_init)
    net.ft_module.apply(weights_init)
    net.pyramid_ext.apply(weights_init)


    print('Loading resume network')
    # state_dict = torch.load(resume_net)  # 这个地方进行重载新的网络
    state_dict = torch.load(args.resume_net)  # 这个地方进行重载新的网络
    # create new OrderedDict that does not contain `module.`

    print('net',net) # 先看看net长什么样子

    from collections import OrderedDict

    new_state_dict = OrderedDict()




    for k, v in state_dict.items():
        print('key: ',k)
        # print('value: ',v)
        head = k[:7]
        if head == 'module.':
            name = k[7:]  # remove `module.`
        else:
            name = k
        new_state_dict[name] = v


    # 权重取舍处理
    net_dict1234=net.state_dict()
    state_dict1234 = intersect_dicts(new_state_dict, net_dict1234, exclude=['loc.4.weight','loc.4.bias','conf.4.weight','conf.4.bias']) # 将不兼容的部分排除掉,不进行载入
    # net.load_state_dict(new_state_dict)  # 这里改造一下,不用一次性全部输入,而是一个一个键的输入即可
    net.load_state_dict(state_dict1234,strict=False)  # 这里改造一下,不用一次性全部输入,而是一个一个键的输入即可

然后就可以训练自己的模型了。

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值