Learning Attentive Pairwise Interaction for Fine-Grained Classification
2020 AAAI。网络结构倒是不复杂,但是这么大的batch size要怎么跑起来。
文章目录
摘要
动机:目前方法都是通过单张图像学习区分性表示,而人类可以通过比较图像对来有效地识别。
网络:注意力成对交互网络(API-Net),通过交互逐步识别成对的细粒度图像。
- 先学习一个共同的特征向量,捕获输入对中的语义差异
- 将该向量与各个向量比较,为每个输入图像生成门
- 端到端,分数排序正则化
代码:https://github.com/PeiqinZhuang/API-Net
1 引言
人类是在比较中识别细粒度对象的。
引入注意力成对交互网络API-Net,可以从一对细粒度图像中自适应地发现对比线索,并通过成对交互进行区分。
API-Net由三个子模块组成,即相互向量学习,门向量生成和成对交互。输入一对图像,先学习一个互矢量,以将输入对的对比线索概括为上下文。再将互向量与单个向量进行比较生成不同的门,可以从每个单个图像的角度突出显示语义差异。将这些门作为区分性注意力执行成对交互。每个图像可以生成两个增强的特征向量,分别从其自身的门矢量和该对中另一个图像的门矢量激活。通过端到端的训练方式和分数排名正则化。
即插即用。
2 API-Net
2.1 互矢量学习
两张图片分别经过主干网络生成 D D D维特征向量 x 1 , x 2 x_1,x_2 x1,x2,映射函数(多层感知机)学习一个 D D D维互矢量 x m = f m ( [ x 1 , x 2 ] ) x_m=f_m([x_1,x_2]) xm=fm([x1,x2])。由于 x m x_m xm是两个的自适应总结,通常包含特征通道,指示成对的高层次的对比线索。
2.2 门向量
将
x
m
x_m
xm作为指导,寻找每个
x
i
x_i
xi包含的对比线索,生成门:
g
i
=
s
i
g
m
o
i
d
(
x
m
⊙
x
i
)
,
i
∈
{
1
,
2
}
g_i=sigmoid(x_m\odot x_i),i\in \{1,2\}
gi=sigmoid(xm⊙xi),i∈{1,2}
g
i
g_i
gi成为有区别的注意力,以不同角度指出了每个
x
i
x_i
xi的语义差异。
2.3 成对交互
通过门向量进行成对交互:
x
1
s
e
l
f
=
x
1
+
x
1
⊙
g
1
x
2
o
t
h
e
r
=
x
2
+
x
2
⊙
g
2
x
1
s
e
l
f
=
x
1
+
x
1
⊙
g
2
x
2
o
t
h
e
r
=
x
2
+
x
2
⊙
g
1
x^{self}_1=x_1+x_1\odot g_1\\ x^{other}_2=x_2+x_2\odot g_2\\ x^{self}_1=x_1+x_1\odot g_2\\ x^{other}_2=x_2+x_2\odot g_1\\
x1self=x1+x1⊙g1x2other=x2+x2⊙g2x1self=x1+x1⊙g2x2other=x2+x2⊙g1
x
i
s
e
l
f
x^{self}_i
xiself由自己的门向量激活,
x
i
o
t
h
e
r
x^{other}_i
xiother由另一个图像的门向量激活。
2.4 训练与测试
特征向量经过softmax分类器,得到 p i j p^j_i pij( j ∈ { s e l f , o t h e r } , i ∈ { 1 , 2 } j\in\{self,other\},i\in \{1,2\} j∈{self,other},i∈{1,2})。
损失:
L
=
L
c
e
+
λ
L
r
k
L=L_{ce}+\lambda L_{rk}
L=Lce+λLrk。
L
r
k
=
∑
i
∈
{
1
,
2
}
max
(
0
,
p
i
o
t
h
e
r
(
c
i
)
−
p
i
s
e
l
f
(
c
i
)
+
ϵ
)
L_{rk}=\sum_{i\in\{1,2\}}\max(0,p^{other}_i(c_i)-p^{self}_i(c_i)+\epsilon)
Lrk=i∈{1,2}∑max(0,piother(ci)−piself(ci)+ϵ)
由自己的门激活得到的结果应当更有区别性。
分批随机抽样 N c l N_{cl} Ncl类,每类随机抽取 N i m N_{im} Nim训练图像,生成其特征向量。对于每个图像,根据欧式距离将其特征与其他特征比较。结果可以为每个图像构造两对:类内最像对、类间最像对。每批共 2 × N c l × N i m 2\times N_{cl}\times N_{im} 2×Ncl×Nim对。
测试时:特征向量直接经过全连接分类。
3 实验
backbone:Resnet101。每个批次中随机抽取30个类别,每类随机采样4张图像,有240个图像对。
3.1 消融实验
基线模型:
互向量:
- 不采用互向量,各自生成门向量
- 双线性池化操作
- 逐元素操作,包括平方差、和、点积三种
- 权重注意力,两层MLP生成两个向量的权重
- MLP
门向量:
- 一个门: g m = s i g m o i d ( x m ) g_m=sigmoid(x_m) gm=sigmoid(xm),一种注意力 x i s e l f = x i + x i ⊙ g m x^{self}_i=x_i+x_i\odot g_m xiself=xi+xi⊙gm
- 两个门
交互:
-
仅使用交叉熵损失
-
交叉熵+排名损失
图像对的构建:
- 随机对
- 类别对
S表示最相似、D表示最不相似
批次的样本数:
3.2 比较SOTA
3.3 可视化
根据门向量得到top-5激活通道,再全局池化前进行可视化。以及resnet101的对应通道。
即使API-Net主要在高级特征上运行,也能自动关注特征图中的可区分对象部分。
4 源码阅读
模型:
def pdist(vectors):
"""
计算欧氏距离:-2(v1+v2) + v1^2 + v2^2
vectors: b*c,b个c维度向量
"""
# vectors.mm(torch.t(vectors)) v1*v2 ,b*b
# vectors.pow(2).sum(dim=1).view(1, -1) v1^2 ,1*b
# vectors.pow(2).sum(dim=1).view(-1, 1) v2^2 ,b*1
distance_matrix = -2 * vectors.mm(torch.t(vectors)) + vectors.pow(2).sum(dim=1).view(1, -1) + vectors.pow(2).sum(
dim=1).view(-1, 1)
return distance_matrix
class API_Net(nn.Module):
def __init__(self):
super(API_Net, self).__init__()
resnet101 = models.resnet101(pretrained=True)
layers = list(resnet101.children())[:-2]
self.conv = nn.Sequential(*layers)
self.avg = nn.AvgPool2d(kernel_size=14, stride=1)
# 互向量生成
self.map1 = nn.Linear(2048 * 2, 512)
self.map2 = nn.Linear(512, 2048)
self.fc = nn.Linear(2048, 200)
self.drop = nn.Dropout(p=0.5)
self.sigmoid = nn.Sigmoid()
def forward(self, images, targets=None, flag='train'):
conv_out = self.conv(images) # b*c*h*w
pool_out = self.avg(conv_out).squeeze() # b*c*1*1 -> b*c
if flag == 'train':
intra_pairs, inter_pairs, intra_labels, inter_labels = self.get_pairs(pool_out, targets)
features1 = torch.cat([pool_out[intra_pairs[:, 0]], pool_out[inter_pairs[:, 0]]], dim=0) # 样本,样本, 2b * c
features2 = torch.cat([pool_out[intra_pairs[:, 1]], pool_out[inter_pairs[:, 1]]],
dim=0) # 类外最像样本,类内最像样本, 2b * c
labels1 = torch.cat([intra_labels[:, 0], inter_labels[:, 0]], dim=0)
labels2 = torch.cat([intra_labels[:, 1], inter_labels[:, 1]], dim=0)
mutual_features = torch.cat([features1, features2],
dim=1) # dim=1拼接,2b * 2c,前b个是(样本,类外最像样本),后b个是(样本,类内最像样本),
map1_out = self.map1(mutual_features)
map2_out = self.drop(map1_out)
map2_out = self.map2(map2_out) # 生成互向量
gate1 = torch.mul(map2_out, features1)
gate1 = self.sigmoid(gate1)
gate2 = torch.mul(map2_out, features2)
gate2 = self.sigmoid(gate2) # 生成门向量
# 成对交互
features1_self = torch.mul(gate1, features1) + features1
features1_other = torch.mul(gate2, features1) + features1
features2_self = torch.mul(gate2, features2) + features2
features2_other = torch.mul(gate1, features2) + features2
logit1_self = self.fc(self.drop(features1_self))
logit1_other = self.fc(self.drop(features1_other))
logit2_self = self.fc(self.drop(features2_self))
logit2_other = self.fc(self.drop(features2_other))
return logit1_self, logit1_other, logit2_self, logit2_other, labels1, labels2
elif flag == 'val':
return self.fc(pool_out)
def get_pairs(self, embeddings, labels):
distance_matrix = pdist(embeddings).detach().cpu().numpy() # b*b
labels = labels.detach().cpu().numpy().reshape(-1, 1) # b*1
num = labels.shape[0] # 样本数
dia_inds = np.diag_indices(num) # (array([0, 1, 2, ..., num]), array([0, 1, 2, ..., num])
lb_eqs = (labels == labels.T) # 同一类标签的坐标
lb_eqs[dia_inds] = False # 自己不能和自己成对
dist_same = distance_matrix.copy()
dist_same[lb_eqs == False] = np.inf # 不能和自己匹配的举例是无穷
intra_idxs = np.argmin(dist_same, axis=1) # 每个样本的同一个类中的最接近的坐标
dist_diff = distance_matrix.copy()
lb_eqs[dia_inds] = True
dist_diff[lb_eqs == True] = np.inf
inter_idxs = np.argmin(dist_diff, axis=1) # 每个样本的不同类中的最接近的坐标
# 组对
intra_pairs = np.zeros([embeddings.shape[0], 2])
inter_pairs = np.zeros([embeddings.shape[0], 2])
intra_labels = np.zeros([embeddings.shape[0], 2])
inter_labels = np.zeros([embeddings.shape[0], 2])
for i in range(embeddings.shape[0]):
# 不同类
intra_labels[i, 0] = labels[i]
intra_labels[i, 1] = labels[intra_idxs[i]]
intra_pairs[i, 0] = i
intra_pairs[i, 1] = intra_idxs[i]
# 同一类
inter_labels[i, 0] = labels[i]
inter_labels[i, 1] = labels[inter_idxs[i]]
inter_pairs[i, 0] = i
inter_pairs[i, 1] = inter_idxs[i]
intra_labels = torch.from_numpy(intra_labels).long().to(device)
intra_pairs = torch.from_numpy(intra_pairs).long().to(device)
inter_labels = torch.from_numpy(inter_labels).long().to(device)
inter_pairs = torch.from_numpy(inter_pairs).long().to(device)
return intra_pairs, inter_pairs, intra_labels, inter_labels
pytorch
实现平衡采样:
class BalancedBatchSampler(BatchSampler):
def __init__(self, dataset, n_classes, n_samples):
self.labels = dataset.labels # dataset自定义的属性
self.labels_set = list(set(self.labels.numpy()))
self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
for label in self.labels_set}
for l in self.labels_set:
np.random.shuffle(self.label_to_indices[l])
self.used_label_indices_count = {label: 0 for label in self.labels_set}
self.count = 0
self.n_classes = n_classes
self.n_samples = n_samples
self.dataset = dataset
self.batch_size = self.n_samples * self.n_classes
def __iter__(self):
self.count = 0
while self.count + self.batch_size < len(self.dataset):
classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
indices = []
for class_ in classes:
indices.extend(self.label_to_indices[class_][
self.used_label_indices_count[class_]:self.used_label_indices_count[
class_] + self.n_samples])
self.used_label_indices_count[class_] += self.n_samples
if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
np.random.shuffle(self.label_to_indices[class_])
self.used_label_indices_count[class_] = 0
yield indices
self.count += self.n_classes * self.n_samples
def __len__(self):
return len(self.dataset) // self.batch_size
"""
使用方法:
"""
train_sampler = BalancedBatchSampler(train_dataset, n_classes, n_samples)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_sampler=train_sampler,
num_workers=args.workers, pin_memory=True)