Paper Reading | Deep Incomplete Multi-view Clustering with Cross-view Partial Sample and Prototype Alignment

Introduction

This paper comes from Xinwang Liu's group and was accepted at CVPR 2023.
Source code: github-CPSPAN

Main contributions:

  • We propose a novel deep network, termed CPSPAN, to handle the IMVC task. Unlike existing multi-view contrastive learning approaches, CPSPAN considers IMVC from a new, partially-aligned perspective: it maximizes the matching alignment between paired-observed data and constructs the cross-view intersection.
  • To solve the prototype-shifted problem caused by incomplete information, CPSPAN further aligns the prototype sets across views, so as to mine consistent cross-view structural information.
  • Extensive experiments clearly demonstrate the effectiveness of the proposed cross-view partial sample and prototype alignment modules, and their superiority over both conventional and deep SOTA methods.

Overall framework of the paper: (framework figure not reproduced here)

The CPSPAN model:

The model consists of three jointly learned modules:

  • Incomplete multi-view representation learning module: a deep autoencoder learns clustering-friendly features for each view, avoiding the instance misalignment caused by missing data (a minimal sketch of this backbone follows the list).
  • Cross-view Partial Sample Alignment module (CPSA): builds cross-view sample correspondences from the paired-observed data, learning more flexible representations and mining better structural information.
  • Shifted Prototype Alignment module (SPA): matches the correspondence between prototype sets, calibrating both the cross-view sample relations and the cross-view prototype relations; this resolves the prototype-shifted problem and further improves clustering performance.
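
The blog does not show the Network class itself. Based on how it is called in the training code below (Network(args.V, args.view_dims, args.feature_dim), with forward returning per-view embeddings zs and reconstructions xrs), a per-view autoencoder backbone might look like the following minimal sketch; the hidden width and activations are illustrative assumptions, not the authors' exact architecture.

import torch
import torch.nn as nn

class Network(nn.Module):
    # One encoder/decoder pair per view; forward returns (embeddings, reconstructions).
    def __init__(self, num_views, view_dims, feature_dim, hidden_dim=500):
        super().__init__()
        self.num_views = num_views
        self.encoders = nn.ModuleList([
            nn.Sequential(nn.Linear(d, hidden_dim), nn.ReLU(),
                          nn.Linear(hidden_dim, feature_dim))
            for d in view_dims])
        self.decoders = nn.ModuleList([
            nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.ReLU(),
                          nn.Linear(hidden_dim, d))
            for d in view_dims])

    def forward(self, xs):
        zs = [enc(x) for enc, x in zip(self.encoders, xs)]   # per-view embeddings
        xrs = [dec(z) for dec, z in zip(self.decoders, zs)]  # per-view reconstructions
        return zs, xrs
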
Main loss functions:

The overall objective (shown as an equation figure in the paper) combines the reconstruction loss (L_rec), the instance alignment loss (L_instance-alignment) and the prototype alignment loss (L_prototype-alignment); α and β are the hyper-parameters that weight the two alignment terms (they appear as para_loss[0] and para_loss[1] in the training code below).
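
Written out under the assumption that α and β simply weight the two alignment terms (mirroring how para_loss is used in the code; the exact pairing of the symbols to the terms is my reading, not the paper's notation):

$$\mathcal{L} = \mathcal{L}_{rec} + \alpha\,\mathcal{L}_{instance\text{-}alignment} + \beta\,\mathcal{L}_{prototype\text{-}alignment}$$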

1. Reconstruction loss
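The equation figure is not reproduced here. A standard per-view autoencoder reconstruction term, consistent with the MSE used in the code below, would be:

$$\mathcal{L}_{rec} = \sum_{v=1}^{V} \left\| \mathbf{X}^{(v)} - D^{(v)}\big(E^{(v)}(\mathbf{X}^{(v)})\big) \right\|_F^{2}$$

where $E^{(v)}$ and $D^{(v)}$ are the encoder and decoder of view $v$; during alignment training only the observed samples of each view contribute (see the miss_vec masking in the code).
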
2. Instance alignment loss (Instance_alignment_loss)
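The equation figure is again omitted. Judging from the code below, which pushes the cosine similarity of paired-observed embeddings toward 1, and assuming Instance_Align_Loss is a mean-squared error against a vector of ones, the term is roughly:

$$\mathcal{L}_{ins} = \sum_{v_1<v_2} \frac{1}{|\mathcal{A}_{v_1 v_2}|} \sum_{i\in\mathcal{A}_{v_1 v_2}} \Big(1-\cos\big(\mathbf{z}_i^{(v_1)},\,\mathbf{z}_i^{(v_2)}\big)\Big)^{2}$$

where $\mathcal{A}_{v_1 v_2}$ is the set of samples observed in both view $v_1$ and view $v_2$.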

3. Prototype alignment loss (prototype_alignment_loss)
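The equation figure is omitted as well. In the code below, the paired-observed embeddings are transposed before a similarity matrix is computed, so the alignment runs over the $d$ feature (prototype) dimensions rather than over samples. Assuming get_Similarity returns pairwise cosine similarities and Proto_Align_Loss is again an MSE against ones on the matching entries, a rough sketch is:

$$\mathcal{L}_{proto} = \sum_{v_1<v_2} \frac{1}{d} \sum_{j=1}^{d} \Big(1-\operatorname{sim}\big(\mathbf{p}_j^{(v_1)},\,\mathbf{p}_j^{(v_2)}\big)\Big)^{2}$$

where $\mathbf{p}_j^{(v)}$ is the $j$-th row of the transposed embedding matrix of view $v$.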

Structure-guided embedding imputation strategy (filling in the data missing from incomplete views)

Based on the similarity matrix obtained in the previous two steps: for a feature missing in view 1, find the neighbor in view 2 whose embedding is closest to it, and directly use that neighbor's view-2 embedding to fill the missing feature in view 1.
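
A minimal sketch of this nearest-neighbor imputation for the two-view case, assuming aligned embeddings fea[v] (N × d tensors) and 0/1 observation masks miss_vec[v]; the function and variable names are illustrative, not the authors':

import torch
import torch.nn.functional as F

def impute_missing(fea, miss_vec):
    # fea: list of (N, d) embedding tensors, one per view; miss_vec: list of (N,) 0/1 observation masks
    fea = [f.clone() for f in fea]
    paired_idx = ((miss_vec[0] == 1) & (miss_vec[1] == 1)).nonzero(as_tuple=True)[0]  # observed in both views
    for v1, v2 in [(0, 1), (1, 0)]:
        missing_idx = (miss_vec[v1] == 0).nonzero(as_tuple=True)[0]  # absent in view v1, present in view v2
        if len(missing_idx) == 0:
            continue
        # cosine similarity between each missing sample's view-v2 embedding and the paired samples' view-v2 embeddings
        sim = F.normalize(fea[v2][missing_idx], dim=1) @ F.normalize(fea[v2][paired_idx], dim=1).t()
        nearest = paired_idx[sim.argmax(dim=1)]
        # fill the missing view-v1 embedding with the neighbor's embedding taken from view v2,
        # which is reasonable because the two embedding spaces have been aligned by the previous steps
        fea[v1][missing_idx] = fea[v2][nearest]
    return fea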

Experimental Performance

(Results tables from the paper, not reproduced here.)

Clustering metrics:
  • ACC
  • NMI
  • F-measure (F-mea)
Datasets:
  • Caltech101-7
  • HandWritten
  • ALOI-100
  • YouTubeFace10
  • EMNIST

Clustering Visualization

(Clustering visualization figure from the paper, not reproduced here.)

Source Code Walkthrough

Initialization and training procedure:

      # Model initialization and training
        model = Network(args.V, args.view_dims, args.feature_dim).to(device)      ### build the model
        optimizer_pretrain = torch.optim.Adam(model.parameters(), lr=args.lr_pre) ### optimizer for pre-training
        fea_emb = pretrain(model, optimizer_pretrain, args, device, X_com, Y_com) ### pre-train and extract features
        optimizer_align = torch.optim.Adam(model.parameters(), lr=args.lr_align)  ### optimizer for alignment training
        fea_end = train_align(model, optimizer_align, args, device, X, Y, Miss_vecs)  ### alignment training

Pre-training module:

# Pre-train the per-view autoencoders with a reconstruction loss, using only the fully observed (complete) samples
def pretrain(model, opt_pre, args, device, X_com, Y_com):
    train_dataset = TrainDataset_Com(X_com, Y_com)
    batch_sampler = Data_Sampler(train_dataset, shuffle=True, batch_size=args.batch_size, drop_last=False)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_sampler=batch_sampler)

    t_progress = tqdm(range(args.pretrain_epochs), desc='Pretraining')
    for epoch in t_progress:
        tot_loss = 0.0
        loss_fn = torch.nn.MSELoss()
        for batch_idx, (xs, ys) in enumerate(train_loader):
            for v in range(args.V):
                xs[v] = torch.squeeze(xs[v]).to(device)
            opt_pre.zero_grad()
            zs, xrs = model(xs)
            loss_list = []
            for v in range(args.V):
                loss_value = loss_fn(xs[v], xrs[v])
                loss_list.append(loss_value)
            loss = sum(loss_list)
            loss.backward()
            opt_pre.step()
            tot_loss += loss.item()
        print('Epoch {}'.format(epoch + 1), 'Loss:{:.6f}'.format(tot_loss / len(train_loader)))

    fea_emb = []
    for v in range(args.V):
        fea_emb.append([])

    # NOTE: X and Y are the data/labels of all samples (not just the complete subset); in the original
    # script they come from the enclosing module scope, since only X_com/Y_com are passed into pretrain
    all_dataset = TrainDataset_Com(X, Y)
    batch_sampler_all = Data_Sampler(all_dataset, shuffle=False, batch_size=args.batch_size, drop_last=False)
    all_loader = torch.utils.data.DataLoader(dataset=all_dataset, batch_sampler=batch_sampler_all)

    # Extract embeddings for every sample with the pre-trained encoders
    with torch.no_grad():
        for batch_idx2, (xs2, _) in enumerate(all_loader):
            for v in range(args.V):
                xs2[v] = torch.squeeze(xs2[v]).to(device)
            zs2, xrs2 = model(xs2)
            for v in range(args.V):
                zs2[v] = zs2[v].cpu()
                fea_emb[v] = fea_emb[v] + zs2[v].tolist()

    for v in range(args.V):
        fea_emb[v] = torch.tensor(fea_emb[v])

    return fea_emb

Alignment training module:

# Train the alignment model
def train_align(model, opt_align, args, device, X, Y, Miss_vecs):
    train_dataset = TrainDataset_All(X, Y, Miss_vecs)
    batch_sampler = Data_Sampler(train_dataset, shuffle=True, batch_size=args.Batch_Align, drop_last=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_sampler=batch_sampler)

    t_progress = tqdm(range(args.align_epochs), desc='Alignment')
    for epoch in t_progress:
        for batch_idx, (x, y, miss_vec) in enumerate(train_loader):
            opt_align.zero_grad()
            ###### Reconstruction loss (loss_recon) ######
            loss_fn = torch.nn.MSELoss().to(device)
            loss_list_recon = []
            for v in range(args.V):
                x[v] = torch.squeeze(x[v]).to(device)
                y[v] = torch.squeeze(y[v]).to(device)
                miss_vec[v] = torch.squeeze(miss_vec[v]).to(device)
            z, xr = model(x)
            for v in range(args.V):
                loss_list_recon.append(loss_fn(x[v][miss_vec[v]>0], xr[v][miss_vec[v]>0]))
            loss_recon = sum(loss_list_recon)

            ###### Instance alignment loss (loss_ins_align) ######
            criterion_ins = Instance_Align_Loss().to(device)
            loss_list_ins = []
            for v1 in range(args.V):
                v2_start = v1 + 1
                for v2 in range(v2_start, args.V):
                    align_index = []
                    for i in range(x[0].shape[0]):
                        if miss_vec[v1][i] == 1 and miss_vec[v2][i] == 1:
                            align_index.append(i)

                    z1 = z[v1][align_index]  
                    z2 = z[v2][align_index]  
                    Dx = F.cosine_similarity(z1, z2, dim=1)
                    gt = torch.ones(z1.shape[0]).to(device)
                    l_tmp2 = criterion_ins(gt, Dx)
                    loss_list_ins.append(l_tmp2)
            loss_ins_align = sum(loss_list_ins)
            ###### Prototype alignment loss (loss_pro_align) ######
            criterion_proto = Proto_Align_Loss().to(device)
            loss_list_pro = []
            for v1 in range(args.V):
                v2_start = v1 + 1
                for v2 in range(v2_start, args.V):
                    align_index = []
                    for i in range(z[0].shape[0]):
                        if miss_vec[v1][i] == 1 and miss_vec[v2][i] == 1:
                            align_index.append(i)

                    p1 = z[v1][align_index].t()
                    p2 = z[v2][align_index].t()
                    gt = torch.ones(p1.shape[0]).to(device)
                    Dp = get_Similarity(p1, p2)
                    l_tmp = criterion_proto(gt, Dp)
                    loss_list_pro.append(l_tmp)
            loss_pro_align = sum(loss_list_pro)
            # para_loss holds the two trade-off weights for the alignment terms (defined outside this snippet)
            loss_total = loss_recon + para_loss[0] * loss_pro_align + para_loss[1] * loss_ins_align  # total objective

            loss_total.backward()
            opt_align.step()

    fea_all = []
    for v in range(args.V):
        fea_all.append([])

    all_dataset = TrainDataset_Com(X, Y)
    batch_sampler_all = Data_Sampler(all_dataset, shuffle=False, batch_size=args.batch_size, drop_last=False)
    all_loader = torch.utils.data.DataLoader(dataset=all_dataset, batch_sampler=batch_sampler_all)

    # Extract the aligned embeddings for every sample with the trained encoders
    with torch.no_grad():
        for batch_idx2, (xs2, _) in enumerate(all_loader):
            for v in range(args.V):
                xs2[v] = torch.squeeze(xs2[v]).to(device)
            zs2, xrs2 = model(xs2)
            for v in range(args.V):
                zs2[v] = zs2[v].cpu()
                fea_all[v] = fea_all[v] + zs2[v].tolist()

    for v in range(args.V):
        fea_all[v] = torch.tensor(fea_all[v])

    return fea_all

Summary of the code: the pre-training step extracts the features fea_emb, and the alignment training step combines the three losses. Finally, the missing embeddings are imputed and K-means clustering is run to obtain the result.
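
A minimal sketch of this final step (impute, concatenate the views, then K-means), reusing the hypothetical impute_missing helper from the imputation section above; Miss_vecs is assumed to hold per-view 0/1 observation tensors and args.n_clusters the number of clusters (both names are assumptions, not taken from the repository):

import torch
from sklearn.cluster import KMeans

# fea_end: list of per-view embedding tensors returned by train_align
fea_filled = impute_missing(fea_end, Miss_vecs)       # fill missing embeddings (see the sketch above)
fea_concat = torch.cat(fea_filled, dim=1).numpy()     # concatenate the views' embeddings
pred = KMeans(n_clusters=args.n_clusters, n_init=10).fit_predict(fea_concat)  # final cluster assignments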

Summary

CPSPAN consists of three modules: an incomplete multi-view representation learning module, a Cross-view Partial Sample Alignment module (CPSA), and a Shifted Prototype Alignment module (SPA). Specifically, CPSA performs instance alignment between different views, while SPA explores the optimal matching correspondence between prototypes. A structure-guided embedding imputation strategy then fills in the missing embeddings. Finally, the complete and imputed embeddings are concatenated and clustered with K-means to obtain the final result.
