siris 显著性排序网络代码解读(inference过程)Inferring Attention Shift Ranks of Objects for Image Saliency

training 过程的代码已经解读完毕,见:siris 显著性排序网络代码解读(training过程)
本文继续解读 inference 过程(这部分比 training 简略,在上篇中写过的内容不再赘述)。
按照相同的思路,根据 README.md 文档要求运行文件的顺序解读。同样默认已经对 mask r-cnn 的结构和代码比较熟悉、阅读过显著性排序网络原论文。

第一部分 获得图片的 rois classIds scores 信息

第一部分涉及的内容

根据 README.md 指示,首先应该运行 python pre_process/object_detection.py ,来看看这个文件的源码:

首先是设置各种路径、设置模式为 inference、获得 config 对象

DATASET_ROOT = "D:/Desktop/ASSR/"   # Change to your location
PRE_PROC_DATA_ROOT = "D:/Desktop/ASSR_Data/"    # Change to your location

if __name__ == '__main__':
    # add pre-trained weight path - backbone pre-trained on salient objects (binary, no rank)
    weight_path = ""

    out_path = PRE_PROC_DATA_ROOT + "object_detection_feat/"

    if not os.path.exists(out_path):
        os.makedirs(out_path)

    mode = "inference"
    config = RankModelConfig()
    log_path = "logs/"
    ...

然后是获得一个 FPN 网络对象实例,也就是特征金字塔网络的对象实例。并加载权重

if __name__ == '__main__':
    ...
    model = FPN(mode=mode, config=config, model_dir=log_path)

    # Load weights
    print("Loading weights ", weight_path)
    model.load_weights(weight_path, by_name=True)
    ...

补充一下特征金字塔网络的输入和输出如下:

inputs = [input_image, input_image_meta, input_anchors]
outputs = [detections, feat_pyr_net_class, feat_pyr_net_bbox,
                           rpn_rois, rpn_class, rpn_bbox, P2, P3, P4, P5]

然后回到 object_detection.py,继续。开头就给 mode 赋值了,所以必然进入分支。
进入分支后,首先通过 dataset 获得 image_ids ,然后通过 image_ids 获得图片文件。其中 dataset.load_image(image_id) 方法返回的是 [H,W,3] 的 Numpy 数组。

if __name__ == '__main__':
    ...
    if mode == "inference":
        # Test Dataset
        dataset = Dataset(DATASET_ROOT, "test")

        predictions = []

        num = len(dataset.img_ids)
        for i in range(num):

            image_id = dataset.img_ids[i]
            print("\n", i + 1, " / ", num, " - ", image_id)

            image = dataset.load_image(image_id)
            ...

获得图片数据之后,将图片数据传入 model.detect 方法,这个方法返回的是一个 list,list中的元素是一个个的字典,一个字典对应一张图片。(但是在这个场景下,图片是一张一张通过 dataset 取出来,然后再一张一张传进 detect )。

if __name__ == '__main__':
    ...
    if mode == "inference":
        ...
        for i in range(num):
			...
			result = model.detect([image], verbose=1)
            pr = result[0]
            ...

model.detect 具体返回情况:

"""Runs the detection pipeline.
images: List of images, potentially of different sizes.

Returns a list of dicts, one dict per image. The dict contains:
rois: [N, (y1, x1, y2, x2)] detection bounding boxes
class_ids: [N] int class IDs
scores: [N] float probability scores for the class IDs
"""
...
results.append({
                "rois": final_rois,
                "class_ids": final_class_ids,
                "scores": final_scores,
                "P2": final_P2,
                "P3": final_P3,
                "P4": final_P4,
                "P5": final_P5,
            })
return results

取返回结果里面的 roisclass_idsscores ,再加上其余的信息,也就是 image_id ,一起通过 pickle 存入本地文件。(关于 pickle 在 training 代码解析中已经说明过了)

if __name__ == '__main__':
    ...
    if mode == "inference":
        ...
        for i in range(num):
			...
			rois = pr["rois"].tolist()
            class_ids = pr["class_ids"].tolist()
            scores = pr["scores"].tolist()

            res = {}
            res["image_id"] = image_id
            res["rois"] = rois
            res["class_ids"] = class_ids
            res["scores"] = scores
            predictions.append(res)
		# 同样,把检测的结果存入本地文件
        o_p = out_path + "object_detection_test_images.pkl"
        with open(o_p, "wb") as f:
            pickle.dump(predictions, f, pickle.HIGHEST_PROTOCOL)

至此,inference 的第一部分的代码就解析完毕。

总结来说,第一部分的任务就是检测显著性物体,得到图片的 roisclass_idsscoresimage_id 这四个信息,存进本地。

第二部分 获取图片的 obj_masksobj_featP5 信息

第二部分涉及的内容

根据 README.md,接下来需要运行 pre_process/pre_process_obj_feat.py
来看看这部分源码:

一模一样的结构套路啊,首先设置好各种路径。

DATASET_ROOT = "D:/Desktop/ASSR/"   # Change to your location
PRE_PROC_DATA_ROOT = "D:/Desktop/ASSR_Data/"    # Change to your location

if __name__ == '__main__':
    # add pre-trained weight path - backbone pre-trained on salient objects (binary, no rank)
    weight_path = ""

    data_split = "test"

    out_path = PRE_PROC_DATA_ROOT + "object_detection_feat/" + data_split + "/"

    if not os.path.exists(out_path):
        os.makedirs(out_path)

    mode = "inference"
    config = RankModelConfig()
    log_path = "logs/"
    ...

然后先获得一个 Module 的实例对象,存为 keras_model,然后再获得一个这个模型的包装类对象(存为 model ),这个包装类里还封装了方便训练和检测 keras_model 的方法,包括载入权重、训练、检测等等。这个流(套)程(路)在 training 的三个部分中都使用了,在解析 training代码的文章中已经写过。

if __name__ == '__main__':
	...
    keras_model = Model_Obj_Feat.build_obj_feat_model(config)
    model_name = "Obj_Feat_Net"

    model = PreProcNet(mode=mode, config=config, model_dir=log_path, keras_model=keras_model, model_name=model_name)

    # Load weights
    print("Loading weights ", weight_path)
    model.load_weights(weight_path, by_name=True)
    ...

同样,必然进入分支。进入分支后也是先获得 dataset,然后获得 img_ids 。随后进入循环,对 img_ids 中每一个 id 做处理:

  1. 调用 DataGenerator.load_inference_data_obj_feat 方法,结果存为 input_data 。这个方法的具体返回结果(也就是 input_data 的具体组成)为:
[image, image_meta, batch_obj_roi]
  1. input_data 传入 detect 方法获得检测结果。(detect 的具体结果也放在后文)
...
detections = self.keras_model.predict(input_data, verbose=0)

# Process detection
obj_masks, obj_feat, P5 = detections
result = {}
result["obj_masks"] = obj_masks
result["obj_feat"] = obj_feat
result["P5"] = P5

return result
  1. 将得到的结果通过 pickle 存入本地文件。
if __name__ == '__main__':
	...
    if mode == "inference":
        # ********** Create Datasets
        obj_det_path = PRE_PROC_DATA_ROOT + "object_detection_feat/"

        # Test Dataset
        dataset = Dataset(DATASET_ROOT, "test", obj_det_path)

        predictions = []

        num = len(dataset.img_ids)
        for i in range(num):

            image_id = dataset.img_ids[i]
            print(i + 1, " / ", num, " - ", image_id)

            input_data = DataGenerator.load_inference_data_obj_feat(dataset, image_id, config)

            result = model.detect(input_data, verbose=1)

            o_p = out_path + image_id
            with open(o_p, "wb") as f:
                pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)

第二部分也到此结束,总结来说,在这部分中,每张图片的 obj_masksobj_featP5 信息被存入本地。

第三部分 得到显著性排序结果

第三部分涉及的内容
根据 README.md,最后运行 evaluation/predict.py
来看看这部分源码:

啊完全是一模一样的流程,就不拆开讲了。

设置路径 → 获取  和它的包装类对象  → 获得 dataset 和 img_ids → 进入循环单独对每个图片 id 做处理 → 通过 dataGenerator 获得 input_data →将 input_data 传进 detect 方法获得检测结果 → 最后存入本地

DATASET_ROOT = "D:/Desktop/ASSR/"   # Change to your location
PRE_PROC_DATA_ROOT = "D:/Desktop/ASSR_Data/"    # Change to your location

if __name__ == '__main__':
    weight_path = "../weights/" + ".h5"

    out_path = "../predictions/"

    if not os.path.exists(out_path):
        os.makedirs(out_path)

    config = InferenceConfig()
    log_path = "logs/"
    mode = "inference"

    keras_model = Model_SAM_SMM.build_saliency_rank_model(config, mode)
    model_name = "Rank_Model_SAM_SMM"
    model = ASRNet(mode=mode, config=config, model_dir=log_path, keras_model=keras_model, model_name=model_name)

    # Load weights
    print("Loading weights ", weight_path)
    model.load_weights(weight_path, by_name=True)

    # ********** Create Datasets
    dataset = DatasetTest(DATASET_ROOT, PRE_PROC_DATA_ROOT, "test")

    # **************************************************
    print("Start Prediction...")

    predictions = []

    num = len(dataset.img_ids)
    for i in range(num):

        image_id = dataset.img_ids[i]
        print(i + 1, " / ", num, " - ", image_id)

        input_data = DataGeneratorTest.load_inference_data(dataset, image_id, config)

        result = model.detect(input_data, verbose=1)

        o_p = out_path + image_id
        with open(o_p, "wb") as f:
            pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)

其中 DataGeneratorTest 返回的结果:

[obj_feat, batch_obj_spatial_masks, p5_feat]

keras_model 的输入和输出:

# *********************** FINAL ***********************
# Model
inputs = [input_obj_features, input_obj_spatial_masks,
          input_P5_feat]
outputs = [object_rank]
model = Model(inputs=inputs, outputs=outputs, name="attn_shift_saliency_rank_model")

也就是说最后 detect 得到的结果是 object_rank
到这里就完成了所有的 inference 过程,得到了我们要预测的东西,也就是显著性物体等级的排序结果。

三合一:三部分的关联

诶,那这三部分的联系在哪儿呢?没说到啊?

诀窍就在 每个dataset 和 dataGenerator 中。

第三部分调用的 dataGenerator 来源于这个文件:evaluation/DataGeneratorTest.py ,在这个文件中只定义了一个方法,源码如下:

def load_inference_data(dataset, image_id, config):

    # Load GT data
    pre_proc_data = dataset.load_obj_pre_proc_data(image_id)

    obj_feat = pre_proc_data["obj_feat"]
    p5_feat = pre_proc_data["P5"]

    # Load Object Spatial Mask
    object_roi_masks = dataset.load_object_roi_masks(image_id)

    # For 32 x 32 image size
    scale = 0.05
    padding = [(4, 4), (0, 0), (0, 0)]
    crop = None
    obj_spatial_masks = utils.resize_mask(object_roi_masks, scale, padding, crop)

    # Transpose and add dimension
    # [32, 32, N] -> [N, 32, 32, 1]
    obj_spatial_masks = np.expand_dims(np.transpose(obj_spatial_masks, [2, 0, 1]), -1)

    # fill rest with 0
    batch_obj_spatial_masks = np.zeros(shape=(config.SAL_OBJ_NUM, 32, 32, 1), dtype=np.float32)

    batch_obj_spatial_masks[:obj_spatial_masks.shape[0]] = obj_spatial_masks

    batch_obj_spatial_masks = np.expand_dims(batch_obj_spatial_masks, axis=0)

    return [obj_feat, batch_obj_spatial_masks, p5_feat]

在这段代码中完成了:

  • 加载 ground truth 数据:调用 dataset.load_obj_pre_proc_data(image_id)。ground truth 数据中包括 obj_featP5
  • 加载 mask 数据:调用 dataset.load_object_roi_masks(image_id)。得到 obj_roi_mask 之后又进行了一系列处理,包括 resize、修改 shape等,得到了 batch_obj_spatial_masks

最后返回的就是这三个数据。

那么关键问题就在于:dataset 具体是怎么把数据 load 进来的呢? dataset 是这个类的实例:evaluation/Dataset.Dataset ,我们具体来看看 load_obj_pre_proc_data 和 load_object_roi_masks 的实现过程:

load_obj_pre_proc_data 就是直接读取了本地文件。这个文件就是第二部分存进去的。

def load_obj_pre_proc_data(self, image_id):
    p = self. pre_proc_data_dir_root + "object_detection_feat/" + self.data_split + "/" + image_id

    with open(p, "rb") as f:
        obj_data = pickle.load(f)

    return obj_data

load_object_roi_masks 主要任务是根据 rois 坐标生成掩码图,逻辑很简单,obj_mask[o_y1:o_y2, o_x1:o_x2] = 1 ← 这句话是核心。

def load_object_roi_masks(self, image_id):
    image = self.load_image(image_id)
    idx = self.img_ids.index(image_id)

    rois = self.rois[idx]

    if len(rois) < 1:
        obj_masks = np.empty([0, 0, 0])
        return obj_masks

    # Reference Mask
    image_shape = image.shape[:2]
    init_mask = np.zeros(shape=image_shape, dtype=np.int32)

    # Generate list of object masks from salient and randomly selected non salient objects if available
    obj_mask_instances = []
    for i in range(len(rois)):
        obj = rois[i]
        # o_x1, o_y1, o_x2, o_y2 = obj
        o_y1, o_x1, o_y2, o_x2 = obj    # original coco format

        obj_mask = init_mask.copy()
        obj_mask[o_y1:o_y2, o_x1:o_x2] = 1
        obj_mask_instances.append(obj_mask)

    obj_masks = np.stack(obj_mask_instances, axis=2).astype(np.bool)

    return obj_masks

我们关心的是,上面代码中的第三句,rois 的坐标哪儿来的?也就是怎么得到 self.rois ?
跳转发现,是在该类的另一个方法 load_obj_data 中得到的。来看这个方法的源码:
容易发现,rois = d[“rois”],而 d 就是 data 中的元素,data 呢?data 是通过 pickle 从本地文件中读进来的。这个本地文件是第一部分存进去的。

def load_obj_data(self):
    data_path = self.pre_proc_data_dir_root + "object_detection_feat/object_detection_test_images.pkl"

    with open(data_path, "rb") as f:
        data = pickle.load(f)

    self.image_id = []
    self.rois = []
    # self.class_ids = []
    # self.scores = []

    for i in range(len(data)):
        d = data[i]
        image_id = d["image_id"]
        rois = d["rois"]
        # class_ids = d["class_ids"]
        scores = d["scores"]

        num_good_objects = sum(s > OBJ_THRESH for s in scores)

        keep_point = num_good_objects
        if keep_point > MAX_OBJ_NUM:
            keep_point = MAX_OBJ_NUM

        self.image_id.append(image_id)
        self.rois.append(rois[:keep_point])
        # self.class_ids.append(class_ids[:keep_point])
        # self.scores.append(scores[:keep_point])

    assert self.img_ids == self.image_id

所以说,第三部分的keras_model 的输入中:

inputs = [input_obj_features, input_obj_spatial_masks,
          input_P5_feat]

input_obj_features 和 input_P5_feat 是第二部分预测的结果存进去的,input_obj_spatial_masks 来源于第一部分的预测结果存进去的数据。这就是三个部分的联系。

以上,inference 部分的代码也解读完毕。欢迎评论区讨论。(如果有人看的话…)

另外补充一点:那第一部分和第二部分存在本地的别的信息,比如 obj_masks 这种,根本没用到?确实没在inference中用到。但是最后 visualize 中,画出来好看的图是需要用到这些信息的。

后续将更新 loss 部分(应该会很短)。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值