最近在用 YOLOv4 做车辆、行人等 9 个类别的目标检测，并尝试修改网络结构，在此记录遇到的一些问题。
结合 Netron 查看网络结构会更清晰。
- 解析网络结构（cfg 文件）：
def parse_model_cfg(path):
    """Parse a darknet (YOLOv3/v4 style) ``.cfg`` file into a list of layer dicts.

    Each ``[section]`` header starts a new dict whose ``'type'`` is the
    section name. The following ``key=value`` lines become entries of that
    dict. Values are kept as stripped strings, except keys containing
    ``anchors``, which are parsed into an ``(N, 2)`` float numpy array.
    ``convolutional`` sections are pre-populated with
    ``batch_normalize = 0`` (overwritten if the file sets it).

    Args:
        path: path to the cfg file.

    Returns:
        List of dicts, one per section, in file order. A ``[net]`` section,
        if present, is included like any other.
    """
    with open(path, 'r') as f:
        raw_lines = f.read().split('\n')
    # Strip *before* filtering so that indented comments and whitespace-only
    # lines are discarded instead of reaching the ``key=value`` split below.
    lines = [x.strip() for x in raw_lines]
    lines = [x for x in lines if x and not x.startswith('#')]

    mdefs = []  # module definitions
    for line in lines:
        if line.startswith('['):  # start of a new section
            mdefs.append({})
            mdefs[-1]['type'] = line[1:-1].rstrip()
            if mdefs[-1]['type'] == 'convolutional':
                mdefs[-1]['batch_normalize'] = 0  # default; may be overwritten later
        else:
            # Split on the first '=' only, so values that contain '=' survive.
            key, val = line.split('=', 1)
            key = key.rstrip()
            if 'anchors' in key:
                mdefs[-1][key] = np.array([float(x) for x in val.split(',')]).reshape((-1, 2))
            else:
                mdefs[-1][key] = val.strip()
    return mdefs
- 将网络结构转换为 dict 并保存：
def model2dict(mdefs, half_filters=False):
    """Index the layers of a parsed cfg and return them in three forms.

    ``mdefs`` is modified in place:

    - a leading ``[net]`` entry is temporarily removed, so layer indices count
      only the layers after it;
    - each remaining layer gets an ``'idx'`` entry (as a string);
    - ``filters`` values are optionally halved;
    - the ``[net]`` entry (if there was one) is put back in front;
    - yolo ``anchors`` are converted to plain lists.

    Every entry is then printed as ``{...},``.

    Args:
        mdefs: list of layer dicts as returned by ``parse_model_cfg``.
        half_filters: if True, halve every even ``filters`` value in
            ``mdefs``; odd values are left unchanged.

    Returns:
        Tuple ``(mdef_dict, mdefss, mdefs)``:

        - ``mdef_dict``: ``{int idx: layer dict}`` built from deep copies.
        - ``mdefss``: the list of those deep copies (without ``[net]``). Each
          has an integer ``'idx'`` and is snapshotted before halving or
          anchor conversion.
        - ``mdefs``: the input list itself, after the in-place changes above.
    """
    mdef_dict = dict()
    netinfo = None
    if mdefs and mdefs[0]['type'] == 'net':
        netinfo = mdefs.pop(0)

    mdefss = copy.deepcopy(mdefs)
    for idx, model in enumerate(mdefss):
        model['idx'] = idx
        layer = mdefs[idx]
        layer['idx'] = str(idx)
        if half_filters and 'filters' in layer:
            filters = int(float(layer['filters']))
            if filters % 2 == 0:
                layer['filters'] = str(filters // 2)
        # Always record the layer; odd filter counts must not drop it.
        mdef_dict[idx] = model

    # Only restore [net] if it was actually removed above.
    if netinfo is not None:
        mdefs.insert(0, netinfo)

    for mm in mdefs:
        if mm['type'] == 'yolo' and hasattr(mm.get('anchors'), 'tolist'):
            mm['anchors'] = mm['anchors'].tolist()
        print('{},'.format(mm))
    return mdef_dict, mdefss, mdefs
def write2yololist(mdefs, path='yolodict.py'):
    """Write layer dicts to a Python module that defines ``yolo_list = [...]``.

    Each layer is written with its ``repr``, one per line, so the generated
    file can be imported and edited by hand. Values that expose ``tolist``
    (numpy arrays and scalars, e.g. yolo ``anchors``) are written as plain
    Python lists/numbers. Their numpy repr (``array(...)``) would make the
    generated module fail to import. ``mdefs`` itself is not modified.

    Args:
        mdefs: list of layer dicts (e.g. the third value returned by
            ``model2dict``).
        path: output file. Defaults to ``yolodict.py`` in the current
            working directory.
    """
    lines = ["yolo_list= ["]
    for mm in mdefs:
        plain = {k: (v.tolist() if hasattr(v, 'tolist') else v) for k, v in mm.items()}
        lines.append('{},\n'.format(plain))
    lines.append(']\n')
    # Python source files are UTF-8 by default, so write them as UTF-8.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(''.join(lines))
- 检查修改后的网络结构是否正确：
def check_simplely_cfg(yolo_list, cls_count=8, img_size=512):
    """Check that a (hand-edited) layer list is shape-consistent.

    Walks the layers in order and records one entry
    ``[idx, in_size, out_size, channels]`` per layer. Feature maps are
    treated as square, so a single spatial size is tracked. Checks applied:

    - a ``shortcut`` must join two entries with equal ``out_size`` and
      ``channels``;
    - all inputs of a ``route`` must share the same ``out_size``.

    The first mismatch is printed and the check stops.

    Side effects on ``yolo_list``:

    - every ``yolo`` layer gets ``classes = str(cls_count)`` and its
      ``anchors`` converted to a numpy array;
    - the layer right before each ``yolo`` (the detection-head conv) gets
      ``head = 1`` and ``filters = (cls_count + 5) * n_masks``, where
      ``n_masks`` is the length of ``mask`` (3 if absent).

    Args:
        yolo_list: list of layer dicts with string values, as produced by
            ``parse_model_cfg``. A ``[net]`` entry is skipped and not
            counted as a layer.
        cls_count: number of classes to configure in the yolo layers.
        img_size: spatial input size fed to the first layer.

    Returns:
        True if every layer checks out, False at the first mismatch.
    """
    temp_inout = []  # one [idx, in_size, out_size, channels] per layer
    for idx, layer in enumerate(yolo_list):
        ltype = layer['type']
        if ltype == 'net':
            # Hyper-parameter section, not a layer: keep route indices aligned.
            continue

        # The previous layer's output is this layer's input.
        if temp_inout:
            in_size, in_channels = temp_inout[-1][2], temp_inout[-1][3]
        else:
            in_size, in_channels = img_size, 0

        if ltype == 'convolutional':
            filter_size = int(layer.get('size', 1))
            stride = int(layer.get('stride', 1))
            padding = (filter_size - 1) // 2 if int(layer.get('pad', 0)) else 0
            out_size = (in_size - filter_size + 2 * padding) // stride + 1
            temp_inout.append([idx, in_size, out_size, int(layer['filters'])])

        elif ltype == 'shortcut':
            src = temp_inout[int(layer['from'])]
            last = temp_inout[-1]
            if src[2:] != last[2:]:
                print('short cut not match: \nlast is :', last)
                print('shortcut is:', src)
                return False
            temp_inout.append([idx, in_size, in_size, in_channels])

        elif ltype == 'route':
            sources = [temp_inout[int(i)] for i in layer['layers'].split(',')]
            if len({s[2] for s in sources}) != 1:
                print('route not match!')
                print(np.array(sources))
                return False
            # Concatenate along channels; `groups` keeps 1/groups of them.
            channels = sum(s[3] for s in sources) // int(layer.get('groups', 1))
            out_size = sources[-1][2]
            temp_inout.append([idx, out_size, out_size, channels])

        elif ltype == 'maxpool':
            stride = int(layer.get('stride', 1))
            if stride == 1:
                out_size = in_size
            else:
                filter_size = int(layer.get('size', stride))
                if 'pad' in layer:
                    # Honour an explicit `pad` flag the same way as for conv layers.
                    total_pad = 2 * ((filter_size - 1) // 2) if int(layer['pad']) else 0
                else:
                    # darknet maxpool: `padding` is the total padding, default size - 1
                    total_pad = int(layer.get('padding', filter_size - 1))
                out_size = (in_size + total_pad - filter_size) // stride + 1
            # Pooling keeps the channel count of its input.
            temp_inout.append([idx, in_size, out_size, in_channels])

        elif ltype == 'upsample':
            out_size = in_size * int(layer.get('stride', 2))
            temp_inout.append([idx, in_size, out_size, in_channels])

        else:
            if ltype == 'yolo':
                layer['anchors'] = np.array(layer['anchors'])
                layer['classes'] = str(cls_count)
                # (x, y, w, h, objectness + classes) predicted for each masked anchor
                n_masks = len(layer['mask'].split(',')) if 'mask' in layer else 3
                head = yolo_list[idx - 1]
                head['head'] = 1
                head['filters'] = str((cls_count + 5) * n_masks)
            # yolo (and, by assumption, unknown) layers keep the shape. An entry
            # is still recorded so relative indices of later layers line up.
            temp_inout.append([idx, in_size, in_size, in_channels])

    print('right~~~~')
    return True
- 将修改后的网络重新写回默认的 cfg 格式：
def conve_cfg(md_list, half_filters=False):
    """Serialise a list of layer dicts back into cfg text.

    Output format, per layer:

    - a ``[type]`` header line;
    - one ``key=value`` line per remaining key, in dict order;
    - a blank line.

    ``anchors`` are flattened and written as ``anchors = a, b, c, ...``.
    Whole numbers are written without a decimal point and fractional
    anchors are kept, not truncated. ``filters`` is written as an integer.

    Args:
        md_list: list of layer dicts, each with a ``'type'`` key.
        half_filters: if True, even ``filters`` values are halved on output.

    Returns:
        The cfg text as a single string.
    """
    out = []
    for layer in md_list:
        out.append('[' + str(layer['type']) + ']' + '\n')
        for key, value in layer.items():
            if key == 'type':
                continue
            if key == 'anchors':
                flat = np.asarray(value, dtype=float).reshape(-1)
                out.append('anchors = ' + ', '.join('{:g}'.format(a) for a in flat) + '\n')
            elif key == 'filters':
                channels = int(float(value))
                if half_filters and channels % 2 == 0:
                    channels //= 2
                out.append('filters=' + str(channels) + '\n')
            else:
                out.append(str(key) + '=' + str(value) + '\n')
        out.append('\n')
    return ''.join(out)