def get_data(input_path):
    """Parse a simple comma-separated annotation file into per-image records.

    Each non-blank line of ``input_path`` must look like
    ``filename,x1,y1,x2,y2,class_name`` and describes one bounding box; an
    image with several boxes appears on several lines.

    Args:
        input_path: Path to the annotation text file.

    Returns:
        A tuple ``(all_data, classes_count, class_mapping)``:

        * ``all_data`` -- shuffled list of per-image dicts with keys
          ``'filepath'``, ``'width'``, ``'height'``, ``'bboxes'`` and
          ``'imageset'`` (``'trainval'`` with probability 5/6, else ``'test'``).
          Each bbox is a dict with keys ``'class'``, ``'x1'``, ``'x2'``,
          ``'y1'``, ``'y2'`` (coordinates as ints).
        * ``classes_count`` -- number of boxes per class name. ``'bg'`` is
          always present (0 unless the file itself annotates ``'bg'`` boxes).
        * ``class_mapping`` -- class name -> contiguous integer index starting
          at 0; ``'bg'`` always gets the last index.

    Raises:
        ValueError: if a non-blank line does not have exactly six
            comma-separated fields, or a coordinate is not an integer.
        IOError: if an image referenced by the file cannot be read.
    """
    all_imgs = {}
    classes_count = {}
    class_mapping = {}

    with open(input_path, 'r') as f:
        print('Parsing annotation files')
        for line_no, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline at EOF) carry no box.
                continue
            line_split = line.split(',')
            if len(line_split) != 6:
                raise ValueError(
                    '{}:{}: expected 6 comma-separated fields '
                    '(filename,x1,y1,x2,y2,class_name), got {}: {!r}'.format(
                        input_path, line_no, len(line_split), line))
            (filename, x1, y1, x2, y2, class_name) = line_split

            classes_count[class_name] = classes_count.get(class_name, 0) + 1
            if class_name not in class_mapping:
                class_mapping[class_name] = len(class_mapping)

            if filename not in all_imgs:
                # The image is read once per file, only to record its size.
                img = cv2.imread(filename)
                if img is None:
                    # cv2.imread signals failure by returning None, not raising.
                    raise IOError('{}:{}: could not read image {!r}'.format(
                        input_path, line_no, filename))
                (rows, cols) = img.shape[:2]
                all_imgs[filename] = {
                    'filepath': filename,
                    'width': cols,
                    'height': rows,
                    'bboxes': [],
                    # randint(0, 6) draws 0..5, so 5 of 6 outcomes -> 'trainval'.
                    'imageset': 'trainval' if np.random.randint(0, 6) > 0 else 'test',
                }

            all_imgs[filename]['bboxes'].append({
                'class': class_name,
                'x1': int(x1),
                'x2': int(x2),
                'y1': int(y1),
                'y2': int(y2),
            })

    all_data = list(all_imgs.values())

    # Background class: keep any count from the file (0 if absent) and make it
    # own the last index, keeping all class indices contiguous.
    classes_count.setdefault('bg', 0)
    if 'bg' not in class_mapping:
        class_mapping['bg'] = len(class_mapping)
    else:
        last_idx = len(class_mapping) - 1
        bg_idx = class_mapping['bg']
        if bg_idx != last_idx:
            # Swap indices with whichever class currently holds the last one.
            for name, idx in class_mapping.items():
                if idx == last_idx:
                    class_mapping[name] = bg_idx
                    break
            class_mapping['bg'] = last_idx

    random.shuffle(all_data)
    print('Training images per class ({} classes) :'.format(len(classes_count)))
    pprint.pprint(classes_count)
    return all_data, classes_count, class_mapping
代码很简单：从命令行读入 input_path，逐行按逗号分割出图片文件路径、框的坐标（x1, y1, x2, y2）以及行末的类别名字符串，分别记录这些信息；然后随机标注每张图片是训练数据（trainval，约 5/6）还是测试数据（test，约 1/6）；最后返回所有图片的信息、各类别的框数以及类别名到编号的映射。
def classifier(base_layers, input_rois, num_rois, nb_classes = 21, trainable=False):
    """Build the RoI classification head (excerpt: the rest of the body is not shown).

    Only the first part of the function is reproduced here; no return
    statement is visible.

    Args:
        base_layers: Backbone feature map, passed to RoiPoolingConv together
            with input_rois.
        input_rois: RoI input, passed to RoiPoolingConv.
        num_rois: Number of RoIs pooled at once; also the first dimension of
            input_shape.
        nb_classes: Number of classes (default 21); not used in the lines shown.
        trainable: Not used in the lines shown (see the review note below).
    """
    pooling_regions = 14
    # NOTE(review): the literal 14s here presumably must match pooling_regions,
    # and 1024 is assumed to be the channel count of base_layers -- confirm.
    input_shape = (num_rois,14,14,1024)
    out_roi_pool = RoiPoolingConv(pooling_regions, num_rois)([base_layers, input_rois])
    # Feed input_shape and the RoI-pooled features out_roi_pool into the
    # classifier layers. Per the author, num_rois is set to 4 by default
    # (by the caller -- this function has no default), giving (4, 14, 14, 1024).
    # NOTE(review): trainable=True is hard-coded, so this function's
    # `trainable` argument does not reach classifier_layers -- confirm intent.
    out = classifier_layers(out_roi_pool, input_shape=input_shape, trainable=True)
关于 TimeDistributed 的用法，参看：
http://blog.csdn.net/xiaojiajia007/article/details/76665016