coco再分组与网络按照分组进行训练

背景:coco有原始的分组,我们通过谱聚类进行了新的分组。需要对coco进行再分组。然后送入网络训练。

目录

一、分组写入文件

1.1 写入参考

1.2 分组结果写入

1.3 分组变量的读出

二、程序中coco分组的运用

2.1 程序调用关系

2.2 聚类分组情况

2.3 更改Config

2.4 重新定义网络

2.5 直接在网络中加载

三、网络结构的定义

3.1 网络输出与标签

3.2 序列按照idx进行重排

3.3  运用index_select调序

3.4 序列之中的调换


一、分组写入文件

1.1 写入参考

可以运用pickle写入文件

参考其他程序中的写入:

    correlations = {}
    correlations.update(pp=A_B) #p(A/B)
    correlations.update(fp=notA_B) # P(not A/B)
    correlations.update(pf=A_notB)
    correlations.update(ff=notA_notB)
    with open('sk_spectral_cluster/coco_correlations.pkl', 'wb') as f:
        print("write correlations in sk_spectral_cluster/coco_correlations.pkl")
        pickle.dump(correlations, f)
    with open('sk_spectral_cluster/coco_names.pkl','wb') as name_file:
        print("write correlations in sk_spectral_cluster/coco_names.pkl")
        pickle.dump(names, name_file)

相当于直接将变量用pickle.dump写入文件f之中。

1.2 分组结果写入

分组之后,我们的结果写入split_groups之中,是一个字典。格式为下面注释中的格式。

    #---------------store the split result into .pkl file-------
    #  in format dict {   0:  [1, 2, 3, 4, 5, 7, 9, 10, 11, 12]
    #                     1:  [46, 47, 49, 50, 51]
    #                     .........
    #                     2:  [22, 23, 32, 34, 35, 38]   }

    with open('sk_spectral_cluster/coco_label_cluster_result.pkl', 'wb') as f:
        print("write cluster result into sk_spectral_cluster/coco_label_cluster_result.pkl")
        pickle.dump(split_groups, f)

命名为 coco_label_cluster_result.pkl

1.3 分组变量的读出

直接根据写入路径即可读出:

    with open('sk_spectral_cluster/coco_label_cluster_result.pkl', 'rb') as f:
        print("loading sk_spectral_cluster/coco_label_cluster_result.pkl ")
        split_groups= pickle.load(f)

    print("split_groups: ",split_groups)

 

二、程序中coco分组的运用

2.1 程序调用关系

general_train之中,调用COCO2014进行数据集的读取。

    train_dataset = COCO2014(args.data, phase='train', inp_name=Config.INP_NAME, is_grouping=True)  # fixme
    val_dataset = COCO2014(args.data, phase='val', inp_name=Config.INP_NAME, is_grouping=True)  # fixme

其中对组的定义通过config传入函数之中:

    GROUPS = 12
    NCLASSES = 80
    NCLASSES_PER_GROUP = [1, 8, 5, 10, 5, 10, 7, 10, 6, 6, 5, 7]  # FIXME: to check
    GROUP_CHANNELS = 512
    CLASS_CHANNELS = 256

 直接将分组的参量传入GROUP之中,

    if Config.MODEL == 'hgat_fc':
        import mymodels.hgat_fc as hgat_fc
        model = hgat_fc.HGAT_FC(Config.BACKBONE, groups=Config.GROUPS, nclasses=Config.NCLASSES,
                                nclasses_per_group=Config.NCLASSES_PER_GROUP,
                                group_channels=Config.GROUP_CHANNELS, class_channels=Config.CLASS_CHANNELS)

2.2 聚类分组情况

split groups:

group: 1   group element numbers 10   group_elements :   [1, 2, 3, 4, 5, 7, 9, 10, 11, 12]
group: 2   group element numbers 5   group_elements :   [46, 47, 49, 50, 51]
group: 3   group element numbers 6   group_elements :   [22, 23, 32, 34, 35, 38]
group: 4   group element numbers 4   group_elements :   [16, 18, 19, 29]
group: 5   group element numbers 5   group_elements :   [6, 24, 25, 26, 28]
group: 6   group element numbers 7   group_elements :   [15, 57, 59, 65, 73, 74, 77]
group: 7   group element numbers 4   group_elements :   [61, 71, 78, 79]
group: 8   group element numbers 7   group_elements :   [39, 58, 68, 69, 70, 72, 75]
group: 9   group element numbers 10   group_elements :   [0, 8, 13, 14, 17, 20, 21, 33, 36, 37]
group: 10   group element numbers 5   group_elements :   [62, 63, 64, 66, 67]
group: 11   group element numbers 10   group_elements :   [40, 41, 42, 43, 44, 45, 53, 55, 56, 60]
group: 12   group element numbers 2   group_elements :   [30, 31]
group: 13   group element numbers 2   group_elements :   [48, 52]
group: 14   group element numbers 2   group_elements :   [27, 54]
group: 15   group element numbers 1   group_elements :   [76]

Final results,group numbers:  15  max_classes_per_group:  10  probability filter threshold:  0.05
group: 1   group element numbers:  10
group_elements :  ['bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'truck', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter']
group: 2   group element numbers:  5
group_elements :  ['banana', 'apple', 'orange', 'broccoli', 'carrot']
group: 3   group element numbers:  6
group_elements :  ['zebra', 'giraffe', 'sports ball', 'baseball bat', 'baseball glove', 'tennis racket']
group: 4   group element numbers:  4
group_elements :  ['dog', 'sheep', 'cow', 'frisbee']
group: 5   group element numbers:  5
group_elements :  ['train', 'backpack', 'umbrella', 'handbag', 'suitcase']
group: 6   group element numbers:  7
group_elements :  ['cat', 'couch', 'bed', 'remote', 'book', 'clock', 'teddy bear']
group: 7   group element numbers:  4
group_elements :  ['toilet', 'sink', 'hair drier', 'toothbrush']
group: 8   group element numbers:  7
group_elements :  ['bottle', 'potted plant', 'microwave', 'oven', 'toaster', 'refrigerator', 'vase']
group: 9   group element numbers:  10
group_elements :  ['person', 'boat', 'bench', 'bird', 'horse', 'elephant', 'bear', 'kite', 'skateboard', 'surfboard']
group: 10   group element numbers:  5
group_elements :  ['tv', 'laptop', 'mouse', 'keyboard', 'cell phone']
group: 11   group element numbers:  10
group_elements :  ['wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'pizza', 'cake', 'chair', 'dining table']
group: 12   group element numbers:  2
group_elements :  ['skis', 'snowboard']
group: 13   group element numbers:  2
group_elements :  ['sandwich', 'hot dog']
group: 14   group element numbers:  2
group_elements :  ['tie', 'donut']
group: 15   group element numbers:  1
group_elements :  ['scissors']

2.3 更改Config

根据分组情况更改相应的程序:

    BACKBONE = 'resnet101'
    GROUPS = 15
    NCLASSES = 80
    NCLASSES_PER_GROUP = [10, 5, 6, 4, 5, 7, 4,7,10,5,10,2,2,2,1]  # FIXME: to check

2.4 重新定义网络

在网络之中加入结构:

    elif Config.MODEL=='clustered_hgat_fc':
        import momydels.clustered_hgat_fc as clustered_hgat_fc
        model=clustered_hgat_fc.HGAT_FC(Config.BACKBONE, groups=Config.GROUPS, nclasses=Config.NCLASSES,
                                nclasses_per_group=Config.NCLASSES_PER_GROUP,
                                group_channels=Config.GROUP_CHANNELS, class_channels=Config.CLASS_CHANNELS)

同时在my_model文件夹之中加入文件

clustered_hgat_fc.py

2.5 直接在网络中加载

直接在网络之中加载分组情况,加载完之后,送入网络,便于分组。

        #fixme----------- load clustered results
        # load groups and group classes
        with open('sk_spectral_cluster/coco_label_cluster_result.pkl', 'rb') as f:
            print("loading sk_spectral_cluster/coco_label_cluster_result.pkl ")
            split_groups = pickle.load(f)
        for key in split_groups:
            print("group:", key, "  group element numbers", len(split_groups[key]), "  group_elements :  ", split_groups[key])
        print("groups=len(split_groups) :", len(split_groups))

        nclasses_per_group = []
        cls_idx_order=[]
        for idx in range(len(split_groups)):
            nclasses_per_group.append(len(split_groups[idx + 1]))
            cls_idx_order=cls_idx_order+split_groups[idx + 1]

        torch_cls_idx_order=torch.IntTensor(cls_idx_order)
        self.torch_cls_idx_order=torch_cls_idx_order
        # print("final idx order:",self.torch_cls_idx_order)

 

三、网络结构的定义

3.1 网络输出与标签

最终输出的为一个n_classes的变量,对于与每个一标签。

        x = torch.cat(outside, dim=1)  # [B,nclasses,C]
        x = torch.cat([self.fcs[i](x[:, i, :]) for i in range(self.nclasses)], dim=1)  # [B,nclasses]
        return x

 

3.2 序列按照idx进行重排

https://blog.csdn.net/qq_25037903/article/details/88651166

torch.index_select

官网地址:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch/#torchindex_select

torch.index_select(input, dim, index, out=None) → Tensor

沿着指定维度对输入进行切片,取index中指定的相应项(index为一个LongTensor),然后返回到一个新的张量, 返回的张量与原始张量_Tensor_有相同的维度(在指定轴上)。

注意: 返回的张量不与原始张量共享内存空间。

参数:

  • input (Tensor) – 输入张量
  • dim (int) – 索引的轴
  • index (LongTensor) – 包含索引下标的一维张量
  • out (Tensor, optional) – 目标张量

理解为,index为目标张量out中的值再原始张量input中的位置。

例如a为

idx2为:

就是根据idx2中的元素值选出a中的位置的值,存入out之中。

经过

a.index_select(0,idx2)

恢复出来为:

3.3  运用index_select调序

下面的x为分组后的组。

        outside = []
        for i in range(self.groups):
            inside = []
            for j in range(self.nclasses_per_group[i]):
                inside.append(self.class_fcs[count](x[:, i, :]))  # [B,C]
                count += 1
            inside = torch.stack(inside, dim=1)  # [B,N,C]
            inside = self.gat2s[i](inside)  # [B,N,C]
            outside.append(inside)
        x = torch.cat(outside, dim=1)  # [B,nclasses,C]
        x = torch.cat([self.fcs[i](x[:, i, :]) for i in range(self.nclasses)], dim=1)  # [B,nclasses]

恢复到分组前,需要将x调换顺序。

这个self.torch_cls_idx_order为每组在原order的顺序。

        torch_cls_idx_order=torch.IntTensor(cls_idx_order)
        self.torch_cls_idx_order=torch_cls_idx_order

我们的目的是:group_classes在class_order中的顺序,恢复出原class_order

        cls=x
        for idx in range(0,self.nclasses):
            cls[idx]=x[self.cls_idx_order[idx]]
        x=cls
        index=self.order_cls_idx
        x=x.index_select(0,index)

此法总是报错。暂时不管

3.4 序列之中的调换

程序之中改为:

        #fixme change order from order in groups to class order
        x = torch.cat([self.fcs[i](x[:, self.idx_in_group_2_cls_idx[i], :]) for i in range(self.nclasses)], dim=1)  # [B,nclasses]

级运用x的顺序,恢复出class的顺序。x为分组后的变量顺序,为了恢复出原始的变量的顺序,我们需要一个对应。

split_group的对应为,下标idx对应每组输出的位置,值value对应每个类中的类别。

那么我们希望根据每组输出的恢复出每个类的标签。

我们现在,已知每个类别,需要找到在split_group中对应的下标,所以需要一个value到idx的对应。

        idx_in_group_2_cls_idx={}
        for idx in range(len(cls_idx_order)):
            idx_in_group_2_cls_idx[cls_idx_order[idx]]=idx
        print("idx_in_group_2_cls_idx",idx_in_group_2_cls_idx)
        self.idx_in_group_2_cls_idx=idx_in_group_2_cls_idx
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

祥瑞Coding

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值