yolov4调整网络结构

最近在用 YOLOv4 做车辆、行人等 9 个类别的目标检测,期间尝试修改网络结构,这里记录一下遇到的一些问题……

建议结合 Netron 可视化网络结构,对照着看会更清晰。

  • 解析网络结构:
def parse_model_cfg(path):
    """Parse a darknet-style layer configuration file into module definitions.

    Every ``[section]`` header starts a new dict whose ``'type'`` entry is the
    section name.  Each following ``key=value`` line is stored on that dict as
    a stripped string, except keys containing ``anchors``, which are parsed
    into an ``(N, 2)`` float array.  ``convolutional`` sections are
    pre-populated with ``batch_normalize = 0`` so the key is always present.

    Blank lines, whitespace-only lines and lines whose first non-blank
    character is ``#`` are ignored.

    Args:
        path: Path of the configuration file to read.

    Returns:
        A list of dicts, one per section, in file order.

    Raises:
        ValueError: If a non-header line does not contain ``=``.
    """
    # Context manager so the handle is closed even if reading fails.
    with open(path, 'r') as file:
        raw_lines = file.read().split('\n')

    # Strip first, then filter: filtering on the raw text lets whitespace-only
    # lines and indented comments through to the key/value split below.
    lines = [x.strip() for x in raw_lines]
    lines = [x for x in lines if x and not x.startswith('#')]

    mdefs = []  # module definitions
    for line in lines:
        if line.startswith('['):  # this marks the start of a new block
            mdefs.append({})
            mdefs[-1]['type'] = line[1:-1].rstrip()
            if mdefs[-1]['type'] == 'convolutional':
                # default that a later "batch_normalize=" line may overwrite
                mdefs[-1]['batch_normalize'] = 0
        else:
            # Split on the first '=' only, so values may themselves contain '='.
            key, sep, val = line.partition('=')
            if not sep:
                raise ValueError(
                    'expected "key=value" in {!r}, got: {!r}'.format(path, line))
            key = key.rstrip()

            if 'anchors' in key:
                # flat "w,h,w,h,..." list -> one (w, h) row per anchor
                mdefs[-1][key] = np.array(
                    [float(x) for x in val.split(',')]).reshape((-1, 2))
            else:
                mdefs[-1][key] = val.strip()
    return mdefs
  • 将网络结构转换为dict保存
def model2dict(mdefs, half_filters=False):
    """Index the layer definitions and build an ``idx -> layer`` lookup.

    A leading ``net`` section, if present, is set aside while the remaining
    layers are numbered from 0 and is put back at the front afterwards.

    ``mdefs`` is modified in place: every layer gets ``'idx'`` as a string,
    ``yolo`` anchors are converted to plain lists, and each entry is printed
    followed by a comma.  The deep copies taken before those changes get
    ``'idx'`` as an int and keep their anchors untouched.

    Args:
        mdefs: List of layer dicts, optionally starting with a ``net`` section.
        half_filters: When true, halve every even ``filters`` value in
            ``mdefs`` (odd values are left as they are).

    Returns:
        A tuple ``(mdef_dict, mdefss, mdefs)``: the ``idx -> copied layer``
        dict, the list of copied layers (without ``net``), and the same
        ``mdefs`` list object that was passed in.
    """
    # The net section is optional; without this guard a list that lacks it
    # would fail later when the section is re-inserted.
    netinfo = None
    if mdefs and mdefs[0]['type'] == 'net':
        netinfo = mdefs.pop(0)

    mdefss = copy.deepcopy(mdefs)
    mdef_dict = {}
    for idx, model in enumerate(mdefss):
        model['idx'] = idx
        # Register the copy before any filter handling so that no layer can
        # be left out of the lookup.
        mdef_dict[idx] = model

        original = mdefs[idx]
        original['idx'] = str(idx)
        if half_filters and 'filters' in original:
            count = int(float(original['filters']))
            if count % 2 == 0:
                original['filters'] = str(count // 2)

    if netinfo is not None:
        mdefs.insert(0, netinfo)
    for mm in mdefs:
        # Anchors may already be plain lists (e.g. on a second call).
        if mm['type'] == 'yolo' and hasattr(mm['anchors'], 'tolist'):
            mm['anchors'] = mm['anchors'].tolist()
        print('{},'.format(mm))
    return mdef_dict, mdefss, mdefs

def write2yololist(mdefs):
    """Write the layer dicts to ``yolodict.py`` as a ``yolo_list`` literal.

    The file is created (or overwritten) relative to the current working
    directory.  Each layer dict is written on its own line followed by a
    comma.  Anchors of ``yolo`` layers that expose ``tolist()`` are written as
    plain nested lists, so the output does not contain ``array(...)`` reprs.
    The dicts in ``mdefs`` themselves are not modified.

    Args:
        mdefs: List of layer dicts; every dict must have a ``'type'`` key.
    """
    lines = ["yolo_list= ["]
    for mm in mdefs:
        anchors = mm.get('anchors') if mm['type'] == 'yolo' else None
        if hasattr(anchors, 'tolist'):
            # Serialize a shallow copy so the caller's dict keeps its array.
            mm = dict(mm, anchors=anchors.tolist())
        lines.append('{},\n'.format(mm))
    lines.append(']\n')
    with open('yolodict.py', 'w') as f:
        f.write(''.join(lines))
  • 检查修改的网络结构是否正确
def check_simplely_cfg(yolo_list, cls_count=8, input_size=512):
    """Roughly check that layer sizes in a modified network still line up.

    Walks the layers in order and keeps one ``[idx, input, output, filters]``
    record per layer, where ``input``/``output`` are a single spatial size.
    ``shortcut`` layers must match the last layer in both output size and
    filter count; ``route`` layers must join sources with the same output
    size.  On the first mismatch a diagnostic is printed and the function
    returns early; otherwise ``'right~~~~'`` is printed at the end.

    ``yolo_list`` is modified in place: each ``yolo`` layer gets its
    ``anchors`` as an array and ``classes`` set to ``cls_count``, and the
    layer directly before it gets ``head = 1`` and
    ``filters = (cls_count + 5) * 3``.

    Layer types other than convolutional, shortcut, route, maxpool, upsample
    and yolo are not tracked.

    Args:
        yolo_list: List of layer dicts with string values, without a leading
            ``net`` section.
        cls_count: Number of classes written into the detection heads.
        input_size: Spatial size fed to the first layer.

    Returns:
        None.
    """
    size = input_size  # spatial size entering the current layer
    filters = 0
    temp_inout = []  # one [idx, input, output, filters] record per layer
    for idx, layer in enumerate(yolo_list):
        kind = layer['type']
        if kind == 'convolutional':
            kernel = int(layer['size'])
            stride = int(layer['stride'])
            # "pad" is a flag: when set, pad by half the kernel size
            padding = (kernel - 1) // 2 if int(layer['pad']) else 0
            output = int((size - kernel + 2 * padding) / stride + 1)
            filters = int(layer['filters'])
            temp_inout.append([idx, size, output, filters])
        elif kind == 'shortcut':
            source = temp_inout[int(layer['from'])]
            last = temp_inout[-1]
            # output size and filter count must both agree for the addition
            if np.array_equal(np.array(source[2:]), np.array(last[2:])):
                record = copy.deepcopy(last)
                record[0] = idx
                temp_inout.append(record)
                output = record[-2]
            else:
                print('short cut not match: \nlast is :', last)
                print('shortcut is:', source)
                return
        elif kind == 'route':
            froms = [int(i) for i in layer['layers'].split(',')]
            fromsl = np.array([temp_inout[i] for i in froms])
            shape = fromsl.shape
            # compare everything between the input column and the filter
            # column, i.e. the output size of every source
            for col in range(2, shape[1] - 1):
                if fromsl[:, col].max() != fromsl[:, col].min():
                    print('route not match!')
                    print(fromsl)
                    return
            record = [int(v) for v in fromsl[-1]]
            record[0] = idx
            record[1] = record[2]
            output = record[2]
            record[-1] = int(fromsl[:, -1].sum())  # concatenation adds channels
            temp_inout.append(record)
        elif kind == 'maxpool':
            if int(layer['stride']) == 1:
                record = copy.deepcopy(temp_inout[-1])
                record[0] = idx
                temp_inout.append(record)
                output = record[-2]
            else:
                stride = int(layer['stride'])
                kernel = int(layer['size'])
                padding = (kernel - 1) // 2 if int(layer['pad']) else 0
                output = int((size - kernel + 2 * padding) / stride + 1)
                # pooling keeps the channel count of the layer feeding it,
                # which is not necessarily the most recent convolution
                channels = temp_inout[-1][-1] if temp_inout else filters
                temp_inout.append([idx, size, output, channels])
        elif kind == 'upsample':
            output = 2 * size
            temp_inout.append([idx, size, output, temp_inout[-1][-1]])
        elif kind == 'yolo':
            layer['anchors'] = np.array(layer['anchors'])
            layer['classes'] = str(cls_count)
            yolo_list[idx - 1]['head'] = 1
            yolo_list[idx - 1]['filters'] = str((cls_count + 5) * 3)
            # Record a pass-through entry so that records stay aligned with
            # layer indices; without it every later route/shortcut lookup
            # would land one layer too early.
            previous = temp_inout[-1]
            temp_inout.append([idx, previous[2], previous[2], previous[3]])
        size = output

    print('right~~~~')
  • 将修改后的网络重新写回默认的 cfg 格式
def conve_cfg(md_list):
    """Render a list of layer dicts as configuration-file text.

    Each layer becomes a ``[type]`` header followed by one line per remaining
    key, in the dict's own order, and a blank line.  ``anchors`` are flattened
    row by row into integers (``anchors = a, b, c, d``), ``filters`` is
    written as an integer, and every other key is written as ``key=value``.

    Args:
        md_list: List of layer dicts; every dict must have a ``'type'`` key.

    Returns:
        The complete configuration text as one string.
    """
    half_filters = False  # manual switch: halve even filter counts on output
    chunks = []
    for layer in md_list:
        chunks.append('[' + str(layer['type']) + ']' + '\n')
        for name, value in layer.items():
            if name == 'type':
                continue  # already emitted as the section header
            if name == 'anchors':
                dims = np.array(value).shape
                flat = [
                    str(int(value[row][col]))
                    for row in range(dims[0])
                    for col in range(dims[1])
                ]
                chunks.append('anchors = ' + ', '.join(flat) + '\n')
            elif name == 'filters':
                channels = int(float(value))
                if half_filters and channels % 2 == 0:
                    channels = int(float(value) / 2)
                chunks.append('filters=' + str(channels) + '\n')
            else:
                chunks.append(str(name) + '=' + str(value) + '\n')
        chunks.append('\n')
    return ''.join(chunks)
  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
需要学习Windows系统YOLOv4的同学请前往《Windows版YOLOv4目标检测实战:原理与源码解析》,课程链接 https://edu.csdn.net/course/detail/29865【为什么要学习这门课】 Linux创始人Linus Torvalds有一句名言:Talk is cheap. Show me the code. 冗谈不够,放码过来!  代码阅读是从基础到提高的必由之路。尤其对深度学习,许多框架隐藏了神经网络底层的实现,只能在上层调包使用,对其内部原理很难认识清晰,不利于进一步优化和创新。YOLOv4是最近推出的基于深度学习的端到端实时目标检测方法。YOLOv4的实现darknet是使用C语言开发的轻型开源深度学习框架,依赖少,可移植性好,可以作为很好的代码阅读案例,让我们深入探究其实现原理。【课程内容与收获】 本课程将解析YOLOv4的实现原理和源码,具体内容包括:- YOLOv4目标检测原理- 神经网络及darknet的C语言实现,尤其是反向传播的梯度求解和误差计算- 代码阅读工具及方法- 深度学习计算的利器:BLAS和GEMM- GPU的CUDA编程方法及在darknet的应用- YOLOv4的程序流程- YOLOv4各层及关键技术的源码解析本课程将提供注释后的darknet的源码程序文件。【相关课程】 除本课程《YOLOv4目标检测:原理与源码解析》外,本人推出了有关YOLOv4目标检测的系列课程,包括:《YOLOv4目标检测实战:训练自己的数据集》《YOLOv4-tiny目标检测实战:训练自己的数据集》《YOLOv4目标检测实战:人脸口罩佩戴检测》《YOLOv4目标检测实战:国交通标志识别》建议先学习一门YOLOv4实战课程,对YOLOv4的使用方法了解以后再学习本课程。【YOLOv4网络模型架构图】 下图由白勇老师绘制  
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值