PVAnet MindSpore Migration


Training Strategy

Evaluate on VOC2007 test:
Training set: COCO trainval + VOC2007 trainval + VOC2012 trainval, then finetune on VOC2007 trainval + VOC2012 trainval. The initial training uses 80 classes; the finetune uses 20 classes.
Evaluate on VOC2012 test:
Training set: COCO trainval + VOC2007 trainval + VOC2012 trainval, then finetune on VOC2007 trainval/test + VOC2012 trainval.

Seeing this training strategy left me speechless, so for now I just train and validate on COCO2017, because a ready-made pipeline exists o(╥﹏╥)o

Dataset Introduction

COCO2017 dataset

MS COCO stands for Microsoft Common Objects in Context; it is a dataset released by Microsoft that can be used for image recognition.

The dataset includes both annotated and unannotated data:

  • 2014: train + val + test
  • 2015: test
  • 2017: train + val + test + unlabeled

COCO2017 image counts:

Split   Images
train   118287
val     5000
test    40670

COCO annotation files (.json)

COCO provides five types of annotations: object detection, keypoint detection, instance segmentation, panoptic segmentation, and image captioning, each stored in its own JSON file. A .json file is essentially one dict.
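Before reading a real annotation file, the top-level structure can be sketched with a tiny synthetic example (all field values here are made up; only the structure matches COCO):

```python
import json

# A minimal, made-up instance annotation with the same five top-level keys as COCO.
mini_coco = {
    "info": {"description": "toy dataset", "year": 2017},
    "licenses": [{"id": 1, "name": "toy license", "url": ""}],
    "images": [{"id": 1, "file_name": "000000000001.jpg", "height": 480, "width": 640}],
    "annotations": [{"id": 1, "image_id": 1, "category_id": 18,
                     "bbox": [10.0, 20.0, 30.0, 40.0], "area": 1200.0, "iscrowd": 0}],
    "categories": [{"id": 18, "name": "dog", "supercategory": "animal"}],
}

# Round-trip through JSON: the file content is literally one dict.
annos = json.loads(json.dumps(mini_coco))
print(type(annos))  # <class 'dict'>
print(sorted(annos.keys()))
```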

Reading the file

import json

filedir = 'instances_val2017.json'

with open(filedir) as f:
    annos = json.load(f)
print(type(annos))  # <class 'dict'>
print(len(annos))  # 5
print(annos.keys())  # dict_keys(['info', 'licenses', 'images', 'annotations', 'categories'])
print(annos['info'])  # {'description': 'COCO 2017 Dataset', 'url': 'http://cocodataset.org', 'version': '1.0', 'year': 2017, 'contributor': 'COCO Consortium', 'date_created': '2017/09/01'}
print(annos['licenses'])
print(annos['images'])
print(annos['annotations'])
print(annos['categories'])

The content of the licenses key:

{'url': 'http://creativecommons.org/licenses/by-nc-sa/2.0/', 'id': 1, 'name': 'Attribution-NonCommercial-ShareAlike License'}
{'url': 'http://creativecommons.org/licenses/by-nc/2.0/', 'id': 2, 'name': 'Attribution-NonCommercial License'}
{'url': 'http://creativecommons.org/licenses/by-nc-nd/2.0/', 'id': 3, 'name': 'Attribution-NonCommercial-NoDerivs License'}
{'url': 'http://creativecommons.org/licenses/by/2.0/', 'id': 4, 'name': 'Attribution License'}
{'url': 'http://creativecommons.org/licenses/by-sa/2.0/', 'id': 5, 'name': 'Attribution-ShareAlike License'}
{'url': 'http://creativecommons.org/licenses/by-nd/2.0/', 'id': 6, 'name': 'Attribution-NoDerivs License'}
{'url': 'http://flickr.com/commons/usage/', 'id': 7, 'name': 'No known copyright restrictions'}
{'url': 'http://www.usa.gov/copyright.shtml', 'id': 8, 'name': 'United States Government Work'}

Part of the content of the images key:

{'license': 3,
 'file_name': 'COCO_val2014_000000016744.jpg',
 'coco_url': 'http://mscoco.org/images/16744',
 'height': 335,
 'width': 500,
 'date_captured': '2013-11-20 14:29:03',
 'flickr_url': 'http://farm3.staticflickr.com/2393/2228750191_11de3ec047_z.jpg',
 'id': 16744
 },
..... repeated for the remaining entries in the same format

The content of the annotations key (showing the first element):

print(annos['annotations'][0].keys())
print(annos['annotations'][0])
dict_keys(['segmentation', 'area', 'iscrowd', 'image_id', 'bbox', 'category_id', 'id'])
{'segmentation': [[510.66, 423.01, 511.72, 420.03, 510.45, 416.0, 510.34, 413.02, 510.77, 410.26, 510.77, 407.5, 510.34, 405.16, 511.51, 402.83, 511.41, 400.49, 510.24, 398.16, 509.39, 397.31, 504.61, 399.22, 502.17, 399.64, 500.89, 401.66, 500.47, 402.08, 499.09, 401.87, 495.79, 401.98, 490.59, 401.77, 488.79, 401.77, 485.39, 398.58, 483.9, 397.31, 481.56, 396.35, 478.48, 395.93, 476.68, 396.03, 475.4, 396.77, 473.92, 398.79, 473.28, 399.96, 473.49, 401.87, 474.56, 403.47, 473.07, 405.59, 473.39, 407.71, 476.68, 409.41, 479.23, 409.73, 481.56, 410.69, 480.4, 411.85, 481.35, 414.93, 479.86, 418.65, 477.32, 420.03, 476.04, 422.58, 479.02, 422.58, 480.29, 423.01, 483.79, 419.93, 486.66, 416.21, 490.06, 415.57, 492.18, 416.85, 491.65, 420.24, 492.82, 422.9, 493.56, 424.39, 496.43, 424.6, 498.02, 423.01, 498.13, 421.31, 497.07, 420.03, 497.07, 415.15, 496.33, 414.51, 501.1, 411.96, 502.06, 411.32, 503.02, 415.04, 503.33, 418.12, 501.1, 420.24, 498.98, 421.63, 500.47, 424.39, 505.03, 423.32, 506.2, 421.31, 507.69, 419.5, 506.31, 423.32, 510.03, 423.01, 510.45, 423.01]], 'area': 702.1057499999998, 'iscrowd': 0, 'image_id': 289343, 'bbox': [473.07, 395.93, 38.65, 28.67], 'category_id': 18, 'id': 1768}



The content of the categories key:

{'supercategory': 'person', 'id': 1, 'name': 'person'}
{'supercategory': 'vehicle', 'id': 2, 'name': 'bicycle'}
{'supercategory': 'vehicle', 'id': 3, 'name': 'car'}
{'supercategory': 'vehicle', 'id': 4, 'name': 'motorcycle'}
{'supercategory': 'vehicle', 'id': 5, 'name': 'airplane'}
{'supercategory': 'vehicle', 'id': 6, 'name': 'bus'}
{'supercategory': 'vehicle', 'id': 7, 'name': 'train'}
{'supercategory': 'vehicle', 'id': 8, 'name': 'truck'}
{'supercategory': 'vehicle', 'id': 9, 'name': 'boat'}
{'supercategory': 'outdoor', 'id': 10, 'name': 'traffic light'}
{'supercategory': 'outdoor', 'id': 11, 'name': 'fire hydrant'}
{'supercategory': 'outdoor', 'id': 13, 'name': 'stop sign'}
{'supercategory': 'outdoor', 'id': 14, 'name': 'parking meter'}
{'supercategory': 'outdoor', 'id': 15, 'name': 'bench'}
{'supercategory': 'animal', 'id': 16, 'name': 'bird'}
{'supercategory': 'animal', 'id': 17, 'name': 'cat'}
{'supercategory': 'animal', 'id': 18, 'name': 'dog'}
{'supercategory': 'animal', 'id': 19, 'name': 'horse'}
{'supercategory': 'animal', 'id': 20, 'name': 'sheep'}
{'supercategory': 'animal', 'id': 21, 'name': 'cow'}
{'supercategory': 'animal', 'id': 22, 'name': 'elephant'}
{'supercategory': 'animal', 'id': 23, 'name': 'bear'}
{'supercategory': 'animal', 'id': 24, 'name': 'zebra'}
{'supercategory': 'animal', 'id': 25, 'name': 'giraffe'}
{'supercategory': 'accessory', 'id': 27, 'name': 'backpack'}
{'supercategory': 'accessory', 'id': 28, 'name': 'umbrella'}
{'supercategory': 'accessory', 'id': 31, 'name': 'handbag'}
{'supercategory': 'accessory', 'id': 32, 'name': 'tie'}
{'supercategory': 'accessory', 'id': 33, 'name': 'suitcase'}
{'supercategory': 'sports', 'id': 34, 'name': 'frisbee'}
{'supercategory': 'sports', 'id': 35, 'name': 'skis'}
{'supercategory': 'sports', 'id': 36, 'name': 'snowboard'}
{'supercategory': 'sports', 'id': 37, 'name': 'sports ball'}
{'supercategory': 'sports', 'id': 38, 'name': 'kite'}
{'supercategory': 'sports', 'id': 39, 'name': 'baseball bat'}
{'supercategory': 'sports', 'id': 40, 'name': 'baseball glove'}
{'supercategory': 'sports', 'id': 41, 'name': 'skateboard'}
{'supercategory': 'sports', 'id': 42, 'name': 'surfboard'}
{'supercategory': 'sports', 'id': 43, 'name': 'tennis racket'}
{'supercategory': 'kitchen', 'id': 44, 'name': 'bottle'}
{'supercategory': 'kitchen', 'id': 46, 'name': 'wine glass'}
{'supercategory': 'kitchen', 'id': 47, 'name': 'cup'}
{'supercategory': 'kitchen', 'id': 48, 'name': 'fork'}
{'supercategory': 'kitchen', 'id': 49, 'name': 'knife'}
{'supercategory': 'kitchen', 'id': 50, 'name': 'spoon'}
{'supercategory': 'kitchen', 'id': 51, 'name': 'bowl'}
{'supercategory': 'food', 'id': 52, 'name': 'banana'}
{'supercategory': 'food', 'id': 53, 'name': 'apple'}
{'supercategory': 'food', 'id': 54, 'name': 'sandwich'}
{'supercategory': 'food', 'id': 55, 'name': 'orange'}
{'supercategory': 'food', 'id': 56, 'name': 'broccoli'}
{'supercategory': 'food', 'id': 57, 'name': 'carrot'}
{'supercategory': 'food', 'id': 58, 'name': 'hot dog'}
{'supercategory': 'food', 'id': 59, 'name': 'pizza'}
{'supercategory': 'food', 'id': 60, 'name': 'donut'}
{'supercategory': 'food', 'id': 61, 'name': 'cake'}
{'supercategory': 'furniture', 'id': 62, 'name': 'chair'}
{'supercategory': 'furniture', 'id': 63, 'name': 'couch'}
{'supercategory': 'furniture', 'id': 64, 'name': 'potted plant'}
{'supercategory': 'furniture', 'id': 65, 'name': 'bed'}
{'supercategory': 'furniture', 'id': 67, 'name': 'dining table'}
{'supercategory': 'furniture', 'id': 70, 'name': 'toilet'}
{'supercategory': 'electronic', 'id': 72, 'name': 'tv'}
{'supercategory': 'electronic', 'id': 73, 'name': 'laptop'}
{'supercategory': 'electronic', 'id': 74, 'name': 'mouse'}
{'supercategory': 'electronic', 'id': 75, 'name': 'remote'}
{'supercategory': 'electronic', 'id': 76, 'name': 'keyboard'}
{'supercategory': 'electronic', 'id': 77, 'name': 'cell phone'}
{'supercategory': 'appliance', 'id': 78, 'name': 'microwave'}
{'supercategory': 'appliance', 'id': 79, 'name': 'oven'}
{'supercategory': 'appliance', 'id': 80, 'name': 'toaster'}
{'supercategory': 'appliance', 'id': 81, 'name': 'sink'}
{'supercategory': 'appliance', 'id': 82, 'name': 'refrigerator'}
{'supercategory': 'indoor', 'id': 84, 'name': 'book'}
{'supercategory': 'indoor', 'id': 85, 'name': 'clock'}
{'supercategory': 'indoor', 'id': 86, 'name': 'vase'}
{'supercategory': 'indoor', 'id': 87, 'name': 'scissors'}
{'supercategory': 'indoor', 'id': 88, 'name': 'teddy bear'}
{'supercategory': 'indoor', 'id': 89, 'name': 'hair drier'}
{'supercategory': 'indoor', 'id': 90, 'name': 'toothbrush'}

Download links

Images:
2017 Train images: http://images.cocodataset.org/zips/train2017.zip
2017 Val images: http://images.cocodataset.org/zips/val2017.zip
2017 Test images: http://images.cocodataset.org/zips/test2017.zip
2017 Unlabeled images: http://images.cocodataset.org/zips/unlabeled2017.zip

Annotations:
2017 Train/Val annotations:
http://images.cocodataset.org/annotations/annotations_trainval2017.zip
http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
2017 Testing Image info:
http://images.cocodataset.org/annotations/image_info_test2017.zip
2017 Unlabeled Image info:
http://images.cocodataset.org/annotations/image_info_unlabeled2017.zip

VOC2007/2012 datasets

Skipped for now

Data Processing

MindRecord introduction

From the official docs: convert datasets to the MindSpore data format.
Users can convert both custom and common datasets to the MindSpore data format (MindRecord) so they can be conveniently loaded into MindSpore for training. MindSpore also applies performance optimizations for this format in some scenarios, so using it can yield better performance. The MindSpore data format has the following properties:

  • Unified storage and access for varied user data, making training data simpler to read.
  • Aggregated storage and efficient reading; easy to manage and move.
  • Efficient encoding/decoding, transparent to the user.
  • Flexible control of partition size, enabling distributed training.

Converting data to MindRecord format

COCO2017

The following code saves the COCO2017 train set in MindRecord format:

import os

import numpy as np
from mindspore.mindrecord import FileWriter

def data_to_mindrecord_byte_image(config, dataset="coco", is_training=True, prefix="PVAnet.mindrecord", file_num=8):
    """Create MindRecord file."""
    mindrecord_dir = config.mindrecord_dir
    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    if dataset == "coco":
        image_files, image_anno_dict = create_coco_label(is_training, config=config)
    else:
        image_files, image_anno_dict = create_train_data_from_txt(config.image_dir, config.anno_path)

    pvanet_json = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int32", "shape": [-1, 6]},
    }
    writer.add_schema(pvanet_json, "pvanet_json")

    for image_name in image_files:
        with open(image_name, 'rb') as f:
            img = f.read()
        annos = np.array(image_anno_dict[image_name], dtype=np.int32)
        row = {"image": img, "annotation": annos}
        writer.write_raw_data([row])
    writer.commit()

What the create_coco_label part of data_to_mindrecord_byte_image does:

train_cls = config.coco_classes  # the class list written in the config file (80 classes plus background)
train_cls_dict = {}
for i, cls in enumerate(train_cls):
    train_cls_dict[cls] = i
print(train_cls_dict)   
{'background': 0, 'person': 1, 'bicycle': 2, 'car': 3, 'motorcycle': 4, 'airplane': 5, 'bus': 6, 'train': 7, 'truck': 8, 'boat': 9, 'traffic light': 10, 'fire hydrant': 11, 'stop sign': 12, 'parking meter': 13, 'bench': 14, 'bird': 15, 'cat': 16, 'dog': 17, 'horse': 18, 'sheep': 19, 'cow': 20, 'elephant': 21, 'bear': 22, 'zebra': 23, 'giraffe': 24, 'backpack': 25, 'umbrella': 26, 'handbag': 27, 'tie': 28, 'suitcase': 29, 'frisbee': 30, 'skis': 31, 'snowboard': 32, 'sports ball': 33, 'kite': 34, 'baseball bat': 35, 'baseball glove': 36, 'skateboard': 37, 'surfboard': 38, 'tennis racket': 39, 'bottle': 40, 'wine glass': 41, 'cup': 42, 'fork': 43, 'knife': 44, 'spoon': 45, 'bowl': 46, 'banana': 47, 'apple': 48, 'sandwich': 49, 'orange': 50, 'broccoli': 51, 'carrot': 52, 'hot dog': 53, 'pizza': 54, 'donut': 55, 'cake': 56, 'chair': 57, 'couch': 58, 'potted plant': 59, 'bed': 60, 'dining table': 61, 'toilet': 62, 'tv': 63, 'laptop': 64, 'mouse': 65, 'remote': 66, 'keyboard': 67, 'cell phone': 68, 'microwave': 69, 'oven': 70, 'toaster': 71, 'sink': 72, 'refrigerator': 73, 'book': 74, 'clock': 75, 'vase': 76, 'scissors': 77, 'teddy bear': 78, 'hair drier': 79, 'toothbrush': 80}

Note that the category ids in classs_dict below skip some numbers, so they do not correspond one-to-one with the contiguous indices above; whether this affects the results is unclear for now.

coco = COCO(anno_json)
classs_dict = {}
cat_ids = coco.loadCats(coco.getCatIds())  # the categories content shown above
for cat in cat_ids:
    classs_dict[cat["id"]] = cat["name"]
print(classs_dict)
{1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove', 41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle', 46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange', 56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed', 67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse', 75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven', 80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush'}
image_ids = coco.getImgIds()  # get the id of every image
image_files = []
image_anno_dict = {}
print(len(image_ids))  # 118287
print(image_ids)
[391895, 522418, 184613..., 475546]
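The gaps in the COCO ids turn out to be harmless as long as labels are always looked up by name: composing classs_dict (coco id → name) with train_cls_dict (name → contiguous index) yields contiguous training ids. A small self-contained check, using only a few of the 80 categories (the subset chosen here is illustrative):

```python
# Subset of the two mappings built above (values copied from the printed dicts).
classs_dict = {1: 'person', 13: 'stop sign', 90: 'toothbrush'}   # coco id -> name (has gaps)
train_cls = ['background', 'person', 'stop sign', 'toothbrush']  # toy contiguous class list
train_cls_dict = {cls: i for i, cls in enumerate(train_cls)}     # name -> contiguous id

# Compose the two: coco category_id -> contiguous training id.
coco_to_train = {coco_id: train_cls_dict[name] for coco_id, name in classs_dict.items()}
print(coco_to_train)  # {1: 1, 13: 2, 90: 3}
```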

Print the first image:

image_info = coco.loadImgs(391895)
print(image_info)
[{'license': 3, 'file_name': '000000391895.jpg', 'coco_url': 'http://images.cocodataset.org/train2017/000000391895.jpg', 'height': 360, 'width': 640, 'date_captured': '2013-11-14 11:18:45', 'flickr_url': 'http://farm9.staticflickr.com/8186/8119368305_4e622c8349_z.jpg', 'id': 391895}]
anno_ids = coco.getAnnIds(imgIds=391895, iscrowd=None)  # ids of the annotation boxes in this image
print(anno_ids)
[151091, 202758, 1260346, 1766676]
anno = coco.loadAnns(anno_ids)
print(anno[0].keys())
dict_keys(['segmentation', 'area', 'iscrowd', 'image_id', 'bbox', 'category_id', 'id'])
print(anno[0])
{'segmentation': [[376.97, 176.91, 398.81, 176.91, 396.38, 147.78, 447.35, 146.17, 448.16, 172.05, 448.16, 178.53, 464.34, 186.62, 464.34, 192.28, 448.97, 195.51, 447.35, 235.96, 441.69, 258.62, 454.63, 268.32, 462.72, 276.41, 471.62, 290.98, 456.25, 298.26, 439.26, 292.59, 431.98, 308.77, 442.49, 313.63, 436.02, 316.86, 429.55, 322.53, 419.84, 354.89, 402.04, 359.74, 401.24, 312.82, 370.49, 303.92, 391.53, 299.87, 391.53, 280.46, 385.06, 278.84, 381.01, 278.84, 359.17, 269.13, 373.73, 261.85, 374.54, 256.19, 378.58, 231.11, 383.44, 205.22, 385.87, 192.28, 373.73, 184.19]], 'area': 12190.44565, 'iscrowd': 0, 'image_id': 391895, 'bbox': [359.17, 146.17, 112.45, 213.57], 'category_id': 4, 'id': 151091}

If a box's category is in train_cls, it is added to annos. The bbox in anno is (x, y, w, h) and must be converted to top-left and bottom-right corners. For each box, the bbox, class_id (0–80), and iscrowd (0/1) are appended to annos; then all the box annotations of the image are stored under its path in image_anno_dict.

image_ids = coco.getImgIds()
image_files = []
image_anno_dict = {}
for img_id in image_ids:
    image_info = coco.loadImgs(img_id)
    file_name = image_info[0]["file_name"]
    anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
    anno = coco.loadAnns(anno_ids)
    image_path = os.path.join(coco_root, data_type, file_name)
    annos = []
    for label in anno:
        bbox = label["bbox"]
        class_name = classs_dict[label["category_id"]]
        if class_name in train_cls:
            x1, x2 = bbox[0], bbox[0] + bbox[2]
            y1, y2 = bbox[1], bbox[1] + bbox[3]
            annos.append([x1, y1, x2, y2] + [train_cls_dict[class_name]] + [int(label["iscrowd"])])
    image_files.append(image_path)
    if annos:
        image_anno_dict[image_path] = np.array(annos)
    else:
        image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])
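The (x, y, w, h) → corner conversion inside the loop can be checked in isolation on a toy annotation (the numbers below are made up; the first row mimics the bbox printed earlier):

```python
import numpy as np

def xywh_to_annos_row(bbox, class_id, iscrowd):
    """Convert a COCO (x, y, w, h) box to [x1, y1, x2, y2, class_id, iscrowd]."""
    x1, y1 = bbox[0], bbox[1]
    x2, y2 = bbox[0] + bbox[2], bbox[1] + bbox[3]
    return [x1, y1, x2, y2, class_id, iscrowd]

rows = [xywh_to_annos_row([359, 146, 112, 213], 4, 0),
        xywh_to_annos_row([10, 20, 30, 40], 18, 0)]
annos = np.array(rows, dtype=np.int32)
print(annos.shape)  # (2, 6) -- matches the MindRecord schema shape [-1, 6]
print(annos[0])     # x2 = 359 + 112 = 471, y2 = 146 + 213 = 359
```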

image_files is the list of all image paths; image_anno_dict maps each image to its box annotation info:

  • bbox
  • class_id (0–80)
  • iscrowd (0/1)

def create_coco_label(is_training, config):
    ...
    return image_files, image_anno_dict

Back in data_to_mindrecord_byte_image, each image (as bytes) and its annos are stored in row, which is then written out in MindRecord format to mindrecord_path.

writer = FileWriter(mindrecord_path, file_num)
pvanet_json = {
    "image": {"type": "bytes"},
    "annotation": {"type": "int32", "shape": [-1, 6]},
}
writer.add_schema(pvanet_json, "pvanet_json")  # the schema defines which fields the dataset contains and their types
for image_name in image_files:
    with open(image_name, 'rb') as f:
        img = f.read()
    annos = np.array(image_anno_dict[image_name], dtype=np.int32)
    row = {"image": img, "annotation": annos}
    writer.write_raw_data([row])
writer.commit()

VOC2007/2012

Skipped for now

Generating PVAnet data

This part is a bit confusing; I haven't fully figured out how the data is processed internally, so here is a brief walkthrough.

import cv2
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C

def create_PVAnet_dataset(config, mindrecord_file, batch_size=2, device_num=1, rank_id=0, is_training=True,
                          num_parallel_workers=8, python_multiprocessing=False):
    """Create PVAnet dataset with MindDataset."""
    cv2.setNumThreads(0)
    de.config.set_prefetch_size(8)
    ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
                        num_parallel_workers=8, shuffle=is_training)
    decode = C.Decode()
    ds = ds.map(input_columns=["image"], operations=decode)
    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training, config=config))

    if is_training:
        ds = ds.map(input_columns=["image", "annotation"],
                    output_columns=["image", "image_shape", "box", "label", "valid_num"],
                    column_order=["image", "image_shape", "box", "label", "valid_num"],
                    operations=compose_map_func, python_multiprocessing=python_multiprocessing,
                    num_parallel_workers=num_parallel_workers)
        ds = ds.batch(batch_size, drop_remainder=True)
    else:
        ds = ds.map(input_columns=["image", "annotation"],
                    output_columns=["image", "image_shape", "box", "label", "valid_num"],
                    column_order=["image", "image_shape", "box", "label", "valid_num"],
                    operations=compose_map_func,
                    num_parallel_workers=num_parallel_workers)
        ds = ds.batch(batch_size, drop_remainder=True)
    return ds

This uses MindDataset to load the MindRecord files into memory. The three parameters num_shards, shard_id, and num_parallel_workers relate to distributed execution.

ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
                    num_parallel_workers=8, shuffle=is_training)

This decodes the images:

decode = C.Decode()
ds = ds.map(input_columns=["image"], operations=decode)

lambda: such a function can take any number of arguments but contains only one expression. Here the lambda takes two parameters, image and annotation, and passes them to preprocess_fn. The operations argument of ds.map() applies this function to the input_columns of every sample; the tuple it returns fills the output_columns.

compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training, config=config))
if is_training:
    ds = ds.map(input_columns=["image", "annotation"],
                output_columns=["image", "image_shape", "box", "label", "valid_num"],
                column_order=["image", "image_shape", "box", "label", "valid_num"],
                operations=compose_map_func, python_multiprocessing=python_multiprocessing,
                num_parallel_workers=num_parallel_workers)
    ds = ds.batch(batch_size, drop_remainder=True)
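The column mapping that ds.map performs can be mimicked in plain Python, without MindSpore: operations is called with the values of input_columns, and the returned tuple fills output_columns. Below, preprocess_fn is replaced by a hypothetical stub that returns five values with the same names as the output columns above:

```python
import numpy as np

def preprocess_fn_stub(image, annotation, is_training):
    """Hypothetical stand-in for the real preprocess_fn: returns the five output columns."""
    image_shape = np.array(image.shape[:2], dtype=np.int32)
    box = annotation[:, :4]
    label = annotation[:, 4]
    valid_num = np.array([annotation.shape[0]], dtype=np.int32)
    return image, image_shape, box, label, valid_num

compose_map_func = lambda image, annotation: preprocess_fn_stub(image, annotation, True)

# One fake dataset row: a 360x640 BGR image and two 6-value annotations.
image = np.zeros((360, 640, 3), dtype=np.uint8)
annotation = np.array([[359, 146, 471, 359, 4, 0],
                       [10, 20, 40, 60, 18, 0]], dtype=np.int32)

# Mimic ds.map: call the function on the input columns, zip results with output columns.
outputs = compose_map_func(image, annotation)
output_columns = ["image", "image_shape", "box", "label", "valid_num"]
row = dict(zip(output_columns, outputs))
print(row["image_shape"])  # [360 640]
print(row["valid_num"])    # [2]
```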

preprocess_fn below handles both preprocessing and data augmentation; it is the function wrapped by compose_map_func above.

def preprocess_fn(image, box, is_training, config):
    """Preprocess function for dataset."""
    def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert):
        image_shape = image_shape[:2]
        input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert

        if config.keep_ratio:
            input_data = rescale_column_test(*input_data, config=config)
        else:
            input_data = resize_column_test(*input_data, config=config)
        input_data = imnormalize_column(*input_data)

        output_data = transpose_column(*input_data)
        return output_data

    def _data_aug(image, box, is_training):
        """Data augmentation function."""
        image_bgr = image.copy()
        image_bgr[:, :, 0] = image[:, :, 2]
        image_bgr[:, :, 1] = image[:, :, 1]
        image_bgr[:, :, 2] = image[:, :, 0]
        image_shape = image_bgr.shape[:2]
        gt_box = box[:, :4]
        gt_label = box[:, 4]
        gt_iscrowd = box[:, 5]

        pad_max_number = 128
        gt_box_new = np.pad(gt_box, ((0, pad_max_number - box.shape[0]), (0, 0)), mode="constant", constant_values=0)
        gt_label_new = np.pad(gt_label, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=-1)
        gt_iscrowd_new = np.pad(gt_iscrowd, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=1)
        gt_iscrowd_new_revert = (~(gt_iscrowd_new.astype(np.bool_))).astype(np.int32)

        if not is_training:
            return _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert)

        flip = (np.random.rand() < config.flip_ratio)
        expand = (np.random.rand() < config.expand_ratio)
        input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert

        if expand:
            input_data = expand_column(*input_data)
        if config.keep_ratio:
            input_data = rescale_column(*input_data, config=config)
        else:
            input_data = resize_column(*input_data, config=config)
        input_data = imnormalize_column(*input_data)
        if flip:
            input_data = flip_column(*input_data)

        output_data = transpose_column(*input_data)
        return output_data

    return _data_aug(image, box, is_training)
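The padding block in _data_aug can be exercised on its own: boxes are padded to 128 rows with zeros, labels with -1, iscrowd with 1, and the revert step flips iscrowd into a validity mask (1 for real non-crowd boxes, 0 for crowd boxes and padding). A numpy sketch, with pad_max_number copied from the code above and a made-up two-box sample:

```python
import numpy as np

pad_max_number = 128
box = np.array([[359, 146, 471, 359, 4, 0],
                [10, 20, 40, 60, 18, 1]], dtype=np.int32)
gt_box = box[:, :4]
gt_label = box[:, 4]
gt_iscrowd = box[:, 5]

n = box.shape[0]
gt_box_new = np.pad(gt_box, ((0, pad_max_number - n), (0, 0)), mode="constant", constant_values=0)
gt_label_new = np.pad(gt_label, (0, pad_max_number - n), mode="constant", constant_values=-1)
gt_iscrowd_new = np.pad(gt_iscrowd, (0, pad_max_number - n), mode="constant", constant_values=1)

# Revert: 1 for real, non-crowd boxes; 0 for crowd boxes and for padding.
gt_iscrowd_new_revert = (~gt_iscrowd_new.astype(np.bool_)).astype(np.int32)

print(gt_box_new.shape)           # (128, 4)
print(gt_label_new[:3])           # [ 4 18 -1]
print(gt_iscrowd_new_revert[:3])  # [1 0 0]
```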

Skipped for now

PVAnet Model Introduction

Reference

https://blog.csdn.net/W1995S/article/details/113123127
