Article Introduction
This post reviews a paper from Liu Xinwang's team, accepted at CVPR 2023.
Source code: github-CPSPAN
Main contributions:
- We propose a novel deep network, termed CPSPAN, to handle the IMVC task. Unlike existing multi-view contrastive learning approaches, we consider IMVC from a novel insight under a partially-aligned setting. To this end, CPSPAN maximizes the matching alignment between paired observed data and constructs cross-view intersections.
- In order to solve the prototype-shifted problem caused by incomplete information, CPSPAN further aligns the prototype sets between different views, so as to mine consistent cross-view structural information.
- Extensive experiments clearly demonstrate the effectiveness of the proposed cross-view partial sample and prototype alignment modules, and their superiority over both conventional and deep SOTA methods.
Main framework:
The CPSPAN model:
The model consists of three jointly learned modules:
- Incomplete multi-view representation learning module: uses a deep autoencoder per view to learn clustering-friendly features, avoiding the instance-misalignment problem caused by missing data.
- Cross-view Partial Sample Alignment module (CPSA): establishes cross-view sample correspondences from the paired observed data, learning more flexible representations and mining better structural information.
- Shifted Prototype Alignment module (SPA): matches correspondences between prototype sets to calibrate both cross-view sample relations and cross-view prototype relations, resolving the prototype-shift problem and further improving clustering performance.
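The representation-learning module above can be sketched as one autoencoder per view. This is a minimal illustration only: the class name `ViewAutoEncoder` and the layer sizes are assumptions, not the repo's actual `Network` class.

```python
import torch
import torch.nn as nn

class ViewAutoEncoder(nn.Module):
    # Hypothetical sketch of a single view-specific autoencoder;
    # CPSPAN's network holds one encoder/decoder pair per view.
    def __init__(self, in_dim, feat_dim):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, 256), nn.ReLU(),
                                     nn.Linear(256, feat_dim))
        self.decoder = nn.Sequential(nn.Linear(feat_dim, 256), nn.ReLU(),
                                     nn.Linear(256, in_dim))

    def forward(self, x):
        z = self.encoder(x)   # clustering-friendly embedding
        xr = self.decoder(z)  # reconstruction of the input view
        return z, xr

x = torch.randn(8, 100)       # toy batch from one 100-dimensional view
z, xr = ViewAutoEncoder(100, 32)(x)
```

The embeddings `z` of all views share the same dimensionality, which is what later allows cross-view alignment and imputation in a common space.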
Main loss functions:
They are the reconstruction loss (L_rec), the instance alignment loss (L_instance_alignment), and the prototype alignment loss (L_prototype_alignment); α and β are trade-off weights between them.
1. Reconstruction loss
2. Instance alignment loss (Instance_alignment_loss)
3. Prototype alignment loss (prototype_alignment_loss)
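Based on how the three losses are combined in `train_align` below (the weights `para_loss[0]` and `para_loss[1]` in the code), the total objective can be written as:

```latex
\mathcal{L}_{total} = \mathcal{L}_{rec}
  + \alpha \, \mathcal{L}_{prototype\text{-}alignment}
  + \beta \, \mathcal{L}_{instance\text{-}alignment}
```

with α and β the trade-off weights for the prototype and instance alignment terms, respectively.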
Structure embedding imputation strategy (filling in the missing data of incomplete views)
Based on the similarity matrix obtained from the previous two steps: for a missing feature in view 1, find the neighbor in view 2 whose embedding is closest, then directly fill the missing feature in view 1 with that neighbor's embedding from view 2.
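The imputation rule above admits more than one reading; the sketch below implements one plausible version for two views. The function name `impute_view1` and the cosine-similarity choice are assumptions, not the repo's exact code.

```python
import torch
import torch.nn.functional as F

def impute_view1(z1, z2, miss1, miss2):
    """Hypothetical sketch of structure embedding imputation (one reading).

    z1, z2       : (N, d) aligned embeddings of view 1 and view 2
    miss1, miss2 : (N,) indicators, 1 = observed, 0 = missing
    """
    z1 = z1.clone()
    z2n = F.normalize(z2, dim=1)           # unit-normalize, so dot product = cosine similarity
    sim = z2n @ z2n.t()                    # (N, N) similarity matrix in view-2 space
    sim.fill_diagonal_(-float('inf'))      # a sample is not its own neighbor
    sim[:, miss2 == 0] = -float('inf')     # neighbors must be observed in view 2
    for i in torch.nonzero(miss1 == 0).flatten():
        j = sim[i].argmax()                # nearest observed neighbor in view 2
        z1[i] = z2[j]                      # fill with that neighbor's view-2 embedding
    return z1
```

Because the alignment losses pull the two views into a shared embedding space, borrowing a view-2 embedding to stand in for a missing view-1 embedding is meaningful.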
Experimental performance
Clustering metrics:
- ACC
- NMI
- F-measure
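Of these metrics, clustering accuracy (ACC) is the least obvious to compute, since predicted cluster ids are arbitrary and must first be matched to the ground-truth labels. A common sketch (not the paper's evaluation code) uses the Hungarian algorithm for this matching:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import normalized_mutual_info_score

def clustering_acc(y_true, y_pred):
    # Clustering ACC: accuracy under the best one-to-one mapping between
    # predicted cluster ids and ground-truth labels (Hungarian algorithm).
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    k = max(y_true.max(), y_pred.max()) + 1
    cost = np.zeros((k, k), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        cost[t, p] += 1                       # co-occurrence counts
    rows, cols = linear_sum_assignment(-cost) # maximize matched pairs
    return cost[rows, cols].sum() / len(y_true)

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [1, 1, 0, 0, 2, 2]                   # same partition, permuted ids
acc = clustering_acc(y_true, y_pred)          # → 1.0
nmi = normalized_mutual_info_score(y_true, y_pred)
```

NMI and F-measure are permutation-invariant by construction, so they need no such matching step.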
Datasets:
- Caltech101-7
- HandWritten
- ALOI-100
- YouTubeFace10
- EMNIST
Clustering visualization
Source code analysis
Initialization and training:
# Model initialization and training
model = Network(args.V, args.view_dims, args.feature_dim).to(device)  # build the model
optimizer_pretrain = torch.optim.Adam(model.parameters(), lr=args.lr_pre)  # pre-training optimizer
fea_emb = pretrain(model, optimizer_pretrain, args, device, X_com, Y_com)  # pre-train and extract features
optimizer_align = torch.optim.Adam(model.parameters(), lr=args.lr_align)  # alignment optimizer
fea_end = train_align(model, optimizer_align, args, device, X, Y, Miss_vecs)  # alignment training
Pre-training module:
def pretrain(model, opt_pre, args, device, X_com, Y_com):
    # Pre-train the autoencoders on the complete (fully observed) samples only
    train_dataset = TrainDataset_Com(X_com, Y_com)
    batch_sampler = Data_Sampler(train_dataset, shuffle=True, batch_size=args.batch_size, drop_last=False)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_sampler=batch_sampler)
    t_progress = tqdm(range(args.pretrain_epochs), desc='Pretraining')
    for epoch in t_progress:
        tot_loss = 0.0
        loss_fn = torch.nn.MSELoss()
        for batch_idx, (xs, ys) in enumerate(train_loader):
            for v in range(args.V):
                xs[v] = torch.squeeze(xs[v]).to(device)
            opt_pre.zero_grad()
            zs, xrs = model(xs)
            loss_list = []
            for v in range(args.V):
                loss_value = loss_fn(xs[v], xrs[v])  # per-view reconstruction loss
                loss_list.append(loss_value)
            loss = sum(loss_list)
            loss.backward()
            opt_pre.step()
            tot_loss += loss.item()
        print('Epoch {}'.format(epoch + 1), 'Loss:{:.6f}'.format(tot_loss / len(train_loader)))
    # Extract embeddings for all samples; note that X and Y here are the full
    # (incomplete) dataset, taken from the enclosing script scope in the repo.
    fea_emb = []
    for v in range(args.V):
        fea_emb.append([])
    all_dataset = TrainDataset_Com(X, Y)
    batch_sampler_all = Data_Sampler(all_dataset, shuffle=False, batch_size=args.batch_size, drop_last=False)
    all_loader = torch.utils.data.DataLoader(dataset=all_dataset, batch_sampler=batch_sampler_all)
    with torch.no_grad():
        for batch_idx2, (xs2, _) in enumerate(all_loader):
            for v in range(args.V):
                xs2[v] = torch.squeeze(xs2[v]).to(device)
            zs2, xrs2 = model(xs2)
            for v in range(args.V):
                zs2[v] = zs2[v].cpu()
                fea_emb[v] = fea_emb[v] + zs2[v].tolist()
    for v in range(args.V):
        fea_emb[v] = torch.tensor(fea_emb[v])
    return fea_emb
Alignment training module:
# Train the alignment model
def train_align(model, opt_align, args, device, X, Y, Miss_vecs):
    train_dataset = TrainDataset_All(X, Y, Miss_vecs)
    batch_sampler = Data_Sampler(train_dataset, shuffle=True, batch_size=args.Batch_Align, drop_last=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_sampler=batch_sampler)
    t_progress = tqdm(range(args.align_epochs), desc='Alignment')
    for epoch in t_progress:
        for batch_idx, (x, y, miss_vec) in enumerate(train_loader):
            opt_align.zero_grad()
            ###### loss_recon: reconstruction loss ######
            loss_fn = torch.nn.MSELoss().to(device)
            loss_list_recon = []
            for v in range(args.V):
                x[v] = torch.squeeze(x[v]).to(device)
                y[v] = torch.squeeze(y[v]).to(device)
                miss_vec[v] = torch.squeeze(miss_vec[v]).to(device)
            z, xr = model(x)
            for v in range(args.V):
                # reconstruct only the observed entries of each view
                loss_list_recon.append(loss_fn(x[v][miss_vec[v] > 0], xr[v][miss_vec[v] > 0]))
            loss_recon = sum(loss_list_recon)
            ###### loss_ins_align: instance alignment loss ######
            criterion_ins = Instance_Align_Loss().to(device)
            loss_list_ins = []
            for v1 in range(args.V):
                for v2 in range(v1 + 1, args.V):
                    # indices of samples observed in both views
                    align_index = []
                    for i in range(x[0].shape[0]):
                        if miss_vec[v1][i] == 1 and miss_vec[v2][i] == 1:
                            align_index.append(i)
                    z1 = z[v1][align_index]
                    z2 = z[v2][align_index]
                    Dx = F.cosine_similarity(z1, z2, dim=1)
                    gt = torch.ones(z1.shape[0]).to(device)
                    l_tmp2 = criterion_ins(gt, Dx)
                    loss_list_ins.append(l_tmp2)
            loss_ins_align = sum(loss_list_ins)
            ###### loss_pro_align: prototype alignment loss ######
            criterion_proto = Proto_Align_Loss().to(device)
            loss_list_pro = []
            for v1 in range(args.V):
                for v2 in range(v1 + 1, args.V):
                    align_index = []
                    for i in range(z[0].shape[0]):
                        if miss_vec[v1][i] == 1 and miss_vec[v2][i] == 1:
                            align_index.append(i)
                    p1 = z[v1][align_index].t()  # transpose: rows are embedding dimensions
                    p2 = z[v2][align_index].t()
                    gt = torch.ones(p1.shape[0]).to(device)
                    Dp = get_Similarity(p1, p2)
                    l_tmp = criterion_proto(gt, Dp)
                    loss_list_pro.append(l_tmp)
            loss_pro_align = sum(loss_list_pro)
            # total loss; para_loss holds the two trade-off weights (set elsewhere in the repo script)
            loss_total = loss_recon + para_loss[0] * loss_pro_align + para_loss[1] * loss_ins_align
            loss_total.backward()
            opt_align.step()
    # extract the final embeddings for all samples
    fea_all = []
    for v in range(args.V):
        fea_all.append([])
    all_dataset = TrainDataset_Com(X, Y)
    batch_sampler_all = Data_Sampler(all_dataset, shuffle=False, batch_size=args.batch_size, drop_last=False)
    all_loader = torch.utils.data.DataLoader(dataset=all_dataset, batch_sampler=batch_sampler_all)
    with torch.no_grad():
        for batch_idx2, (xs2, _) in enumerate(all_loader):
            for v in range(args.V):
                xs2[v] = torch.squeeze(xs2[v]).to(device)
            zs2, xrs2 = model(xs2)
            for v in range(args.V):
                zs2[v] = zs2[v].cpu()
                fea_all[v] = fea_all[v] + zs2[v].tolist()
    for v in range(args.V):
        fea_all[v] = torch.tensor(fea_all[v])
    return fea_all
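The helpers `Instance_Align_Loss`, `Proto_Align_Loss`, and `get_Similarity` are defined elsewhere in the repo and are not shown in this excerpt. The sketch below gives plausible implementations consistent with how they are called above (similarities pushed towards a vector of ones); the real definitions may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Plausible sketches, inferred from the call sites in train_align.

def get_Similarity(p1, p2):
    # Row-wise cosine similarity between two (transposed) feature matrices.
    return F.cosine_similarity(p1, p2, dim=1)

class Instance_Align_Loss(nn.Module):
    # Pushes each paired sample's cross-view cosine similarity towards 1.
    def forward(self, gt, similarity):
        return F.mse_loss(similarity, gt)

class Proto_Align_Loss(nn.Module):
    # Same idea, applied per embedding dimension of the transposed features.
    def forward(self, gt, similarity):
        return F.mse_loss(similarity, gt)
```

With perfectly aligned views the similarities are all 1 and both losses vanish, which is the fixed point the alignment training drives towards.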
Summary: the pre-training step extracts the features fea_emb, and the alignment training step computes the three losses. Finally the missing data are imputed and K-means clustering is applied.
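The final clustering step can be sketched as follows: concatenate the (imputed) per-view embeddings and run K-means on the joint representation. The random tensors stand in for the real `fea_all` embeddings; the cluster count 5 is illustrative.

```python
import torch
from sklearn.cluster import KMeans

# Stand-in embeddings: two views, 100 samples, 32-dim features each.
fea_all = [torch.randn(100, 32), torch.randn(100, 32)]
joint = torch.cat(fea_all, dim=1).numpy()   # (N, sum of per-view feature dims)
labels = KMeans(n_clusters=5, n_init=10).fit_predict(joint)
```

The predicted `labels` are then scored against the ground truth with ACC, NMI, and F-measure as reported in the experiments.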
Summary
CPSPAN consists of three modules: the incomplete multi-view representation learning module, the Cross-view Partial Sample Alignment module (CPSA), and the Shifted Prototype Alignment module (SPA). Specifically, CPSA performs instance alignment between different views, while SPA explores the optimal matching correspondence between prototypes. A structure embedding imputation strategy then fills in the missing embeddings. Finally, the complete and imputed embeddings are concatenated and clustered with K-means to obtain the final result.