摘要
本文复现了WS-DAN(Weakly Supervised Data Augmentation Network)在CUB-200-2011上的运行效果,测试集准确率达到了论文(See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification)所述的top-1 89.4%。下文将介绍论文思路和分析代码逻辑。
目录
1. inception v3 在CUB200-2011上的效果
WS-DAN是基于inception v3模型的,首先查看其运行效果。
1.1数据集预处理
CUB200-2011数据集是没有划分训练集和验证集的,但是官方给出了参考文件train_test_split.txt、image_class_labels.txt、bounding_boxes.txt。train_test_split.txt用于划分训练集和验证集,5994张训练集、5794张验证集。image_class_labels.txt存放标签。bounding_boxes.txt是物体边界框。利用下方脚本将images分成train和test两个文件夹。
import os
import pandas as pd
from PIL import Image
from shutil import copyfile
def makedir(path):
"""
if path does not exist in the file system, create it
"""
if not os.path.exists(path):
os.makedirs(path)
# set paths
rootpath = './dataset/CUB200_origin/'
imgspath = rootpath + 'images/'
trainpath = 'dataset/cub200_cropped/train/'
testpath = 'dataset/cub200_cropped/test/'
# read img names, bounding_boxes
names = pd.read_table(rootpath + 'images.txt', delimiter=' ', names=['id', 'name'])
names = names.to_numpy()
boxs = pd.read_table(rootpath + 'bounding_boxes.txt', delimiter=' ', names=['id', 'x', 'y', 'width', 'height'])
boxs = boxs.to_numpy()
# crop imgs 裁剪覆盖原图
for i in range(11788):
im = Image.open(imgspath + names[i][1])
im = im.crop((boxs[i][1], boxs[i][2], boxs[i][1] + boxs[i][3], boxs[i][2] + boxs[i][4]))
im.save(imgspath + names[i][1], quality=95)
print('{} imgs cropped and saved.'.format(i + 1))
print('All Done.')
# mkdir for cropped imgs 创建分类文件夹
folders = pd.read_table(rootpath + 'classes.txt', delimiter=' ', names=['id', 'folder'])
folders = folders.to_numpy()
for i in range(200):
makedir(trainpath + folders[i][1])
makedir(testpath + folders[i][1])
# split imgs 分类图片,按照训练集和验证集
labels = pd.read_table(rootpath + 'train_test_split.txt', delimiter=' ', names=['id', 'label'])
labels = labels.to_numpy()
for i in range(11788):
if labels[i][1] == 1:
copyfile(imgspath + names[i][1], trainpath + names[i][1])
else:
copyfile(imgspath + names[i][1], testpath + names[i][1])
print('{} imgs splited.'.format(i + 1))
print('All Done.')
脚本中还利用了bounding_boxes.txt这个文件裁剪数据集,这也是一种数据增强的方式(突出目标),也就是WS-DAN题目所述——看得更近。只不过这里是人工的,而WS-DAN是学习出来的。
1.2 inception v3 训练与验证
transform = transforms.Compose([
# 对图像进行随机的裁剪crop以后再resize成固定大小(299*299)
transforms.RandomResizedCrop(299),
# 随机旋转20度(顺时针和逆时针)
transforms.RandomRotation(20),
# 随机水平翻转
transforms.RandomHorizontalFlip(p=0.5),
# 将数据转换为tensor
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.226, 0.224, 0.225))
])
在模型最后增加全连接层,损失函数:交叉熵,优化器:随机梯度下降。
train_loader = DataLoader(dataset=train_set, batch_size=50, shuffle=True, num_workers=0, drop_last=False)
test_loader = DataLoader(dataset=test_set, batch_size=50, shuffle=True, num_workers=0, drop_last=False)
# 加载模型
inception_v3_model = models.inception_v3(pretrained=True)
inception_v3_model.fc = nn.Linear(2048, 200)
# inception_v3_model.load_state_dict(torch.load("/kaggle/input/", map_location=device))
inception_v3_model.to(device)
# 损失函数
loss = nn.CrossEntropyLoss() # 交叉熵
loss.to(device)
# 优化,随机梯度下降
LR = 0.01
optim = torch.optim.SGD(inception_v3_model.parameters(), LR, momentum=0.9)
1.3 改进数据集处理
1.2中准确率低的原因是transform中的transforms.RandomResizedCrop(299)语句,该语句对输入图像进行随机裁剪,由于数据集在1.1中已经进行了目标裁剪,所以此处的随机裁剪会损失目标的大部分特征,使得网络难以识别。删除该语句后准确率有了大幅提升。
2. WS-DAN模型
2.1 训练和验证模型
2.1.1 BAP
class BAP(nn.Module):
def __init__(self, **kwargs):
super(BAP, self).__init__()
def forward(self, feature_maps, attention_maps):
feature_shape = feature_maps.size() # N x 768 x 17 x 17
attention_shape = attention_maps.size() # N x 32(num_parts) x 17 x 17
# print(feature_shape, attention_shape)
phi_I = torch.einsum('imjk,injk->imn', (attention_maps, feature_maps)) # 矩阵对应点相乘相加 全局平均池化 N x 32 x 768
phi_I = torch.div(phi_I, float(attention_shape[2] * attention_shape[3])) # 除法,除上17*17,归一化
phi_I = torch.mul(torch.sign(phi_I), torch.sqrt(torch.abs(phi_I) + 1e-12))
phi_I = phi_I.view(feature_shape[0], -1) # 相当于resize N*(32*768)
raw_features = torch.nn.functional.normalize(phi_I, dim=-1) # 除以每个样本的范数 N*(32*768)
pooling_features = raw_features*100
# print(pooling_features.shape)
return raw_features, pooling_features
代码中,点乘后直接相加然后进行归一化处理,考虑到归一化后的raw_features太小,用放大100倍的pooling_features进行后续处理,raw_features用于进行center loss的计算,后面会介绍。
2.1.2 crop
让我们再次回到Figure 3,还剩下一个部分,注意力图引导数据增强(紫色部分),该部分是论文的主要创新点。在attention maps中随机抽取两张图,这个随机不是完全的随机,是有概率的随机,由part_weights进行控制。首先对32张注意力图进行全局平均池化变成32个数值,对32个数值进行标准化(每个数除以总和)得到32个概率值,由这个概率值抽取两张attention maps。每个batch size有16张图,每张图对应32张attention maps,由抽取出的两张attention maps进行数据增强。
A
k
∗
=
A
k
−
m
i
n
(
A
k
)
m
a
x
(
A
k
)
−
m
i
n
(
A
k
)
(
1
)
A_k^* = \frac{A_k - min(A_k)}{max(A_k) - min(A_k)}\quad\quad\quad(1)
Ak∗=max(Ak)−min(Ak)Ak−min(Ak)(1)
然后对选择一个阈值
θ
c
∈
[
0
,
1
]
(
2
)
\theta_c \in [0,1]\quad\quad\quad(2)
θc∈[0,1](2)
大于阈值置1,小于阈值置0
C
k
(
i
,
j
)
=
{
1
i
f
A
k
∗
(
i
,
j
)
>
θ
c
0
o
t
h
e
r
w
i
s
e
.
(
3
)
C_k(i,j) = \begin{cases} 1 & i f\quad A_k^*(i,j)> \theta_c\\ 0 & otherwise. \\ \end{cases}\quad\quad\quad(3)
Ck(i,j)={10ifAk∗(i,j)>θcotherwise.(3)
最后找到一个边界框可以覆盖所有为1的数值,用这个边界框的索引裁剪输入图片生成第一张数据增强图片,这个过程与1.1中的处理方式类似,但是此时不需要人工干预,网络自己可以生成。
# --------create crop imgs
mask = attention_map[selected_index, :, :] # 取出一张注意力图
# mask = (mask-mask.min())/(mask.max()-mask.min()) 没有按照论文的套路
threshold = random.uniform(0.4, 0.6)
# 注意力图中大于最大值一半的点索引,返回两行,第一行为行索引,第二行为列索引
itemindex = np.where(mask.cpu() >= mask.cpu().max()*threshold)
# print(itemindex.shape)
# itemindex = torch.nonzero(mask >= threshold*mask.max())
padding_h = int(0.1*H) # 51.2
padding_w = int(0.1*W) # 51.2
height_min = itemindex[0].min() # 找出行中最小值的索引
height_min = max(0, height_min-padding_h) # 往前退51.2
height_max = itemindex[0].max() + padding_h # 找到最大值索引,往后退51.2
width_min = itemindex[1].min()
width_min = max(0, width_min-padding_w)
width_max = itemindex[1].max() + padding_w
# print('numpy',height_min,height_max,width_min,width_max)
# 在每一张(16)输入图像中裁剪出注意力图数值较大的区域
out_img = input_tensor[i][:, height_min:height_max, width_min:width_max].unsqueeze(0) # 升一个维度,最高维度留给batch
out_img = torch.nn.functional.interpolate(out_img, size=(W, H), mode='bilinear', align_corners=True) # 恢复大小
out_img = out_img.squeeze(0) # 插值后再升维
ret_imgs.append(out_img)
代码的处理方式略有不同,没有严格按照公式(1)和(3)。首先阈值是随机生成的,直接通过np.where找出边界框索引,没有使用公式(1)和(3)。代码中还做了一个padding,扩充了边界框的大小。
2.1.3 drop
回到Figure 5,第二张注意力图用于drop,挖去注意力高的特征,让网络可以注意到目标的其他特征,处理方式与crop相反。
C
k
(
i
,
j
)
=
{
0
i
f
A
k
∗
(
i
,
j
)
>
θ
d
1
o
t
h
e
r
w
i
s
e
.
(
4
)
C_k(i,j) = \begin{cases} 0 & i f\quad A_k^*(i,j)> \theta_d\\ 1 & otherwise. \\ \end{cases}\quad\quad\quad(4)
Ck(i,j)={01ifAk∗(i,j)>θdotherwise.(4)
# --------create drop imgs
mask2 = attention_map[selected_index2:selected_index2 + 1, :, :] # 取出一张注意力图
threshold = random.uniform(0.2, 0.5)
mask2 = (mask2 < threshold * mask2.max()).float() # 小于最大值阈值的值1
masks.append(mask2)
# bboxes = np.asarray(bboxes, np.float32)
crop_imgs = torch.stack(ret_imgs)
masks = torch.stack(masks)
drop_imgs = input_tensor*masks # 输入图像与mask相乘,掩盖注意力高的地方
阈值是随机选取的,这里是利用图像乘法用0掩盖注意力高的部分。
2.2 Loss
如Figure 3和Figure 5,图片输入网络后产生注意力图,注意力图指导输入图片产生两张数据增强图片,也就是一共三张图片输入网络,产生loss1、loss2、loss3,最终的loss如公式(5)。
l
o
s
s
=
l
o
s
s
1
+
l
o
s
s
2
+
l
o
s
s
3
3
+
c
e
n
t
e
r
_
l
o
s
s
(
5
)
loss=\frac{loss1+loss2+loss3}{3}+center\_loss\quad\quad\quad(5)
loss=3loss1+loss2+loss3+center_loss(5)
2.2.1 center loss
公式(5)中还多了一部分 center loss。center loss用于指导每次训练产生的注意力图尽可能相似。论文中的center loss借鉴了人脸识别,融合center loss是该论文的创新点,也是反映了题目中的弱监督。先来看center loss是怎么产生的。
L
A
=
∑
k
=
1
M
∣
∣
f
k
−
c
k
∣
∣
2
(
6
)
L_A=\sum\limits_{k=1}^{M}||f_k-c_k||^2\quad\quad\quad(6)
LA=k=1∑M∣∣fk−ck∣∣2(6)
M是batch size,fk是BAP求出的24576个数值的feature,ck就是我们的特征中心(初始化为0),通过batch不断修正ck。
c
k
←
c
k
+
β
(
f
k
−
c
k
)
(
7
)
c_k←c_k+\beta(f_k-c_k)\quad\quad\quad(7)
ck←ck+β(fk−ck)(7)
那么fk跟attention maps有什么关系呢?如Figure 4,fk是由feature maps和attention maps共同生成的,而attention maps是在feature maps中选取出来的,所以要想attention maps相似,feature maps也要相似,center loss就在参数上就采用了两者的BAP。
2.3 测试模型
测试与训练和验证的过程略有不同。
for i in range(batch_size):
attention_map = attention_maps[i]
# print(attention_map.shape)
# 测试时候对所有特征图取平均,包含所有特征
mask = attention_map.mean(dim=0)
# print(type(mask))
# mask = (mask-mask.min())/(mask.max()-mask.min())
# threshold = random.uniform(0.4, 0.6)
threshold = 0.1
max_activate = mask.max()
min_activate = threshold * max_activate
itemindex = torch.nonzero(mask >= min_activate) #效果与where类似
padding_h = int(0.05*H)
padding_w = int(0.05*W)
height_min = itemindex[:, 0].min()
height_min = max(0,height_min-padding_h)
height_max = itemindex[:, 0].max() + padding_h
width_min = itemindex[:, 1].min()
width_min = max(0,width_min-padding_w)
width_max = itemindex[:, 1].max() + padding_w
# print(height_min,height_max,width_min,width_max)
out_img = input_tensor[i][:,height_min:height_max,width_min:width_max].unsqueeze(0)
out_img = torch.nn.functional.interpolate(out_img,size=(W,H),mode='bilinear',align_corners=True)
out_img = out_img.squeeze(0)
# print(out_img.shape)
ret_imgs.append(out_img)
2.4 accuracy
模型一个batch输入3 x batch size张图片进行训练,而在代码中,只采用了output1来统计训练和验证的accuracy。
prec1, prec5 = accuracy(output1, target, topk=(1, 5))
但在测试时有所不同,测试一个batch输入2 x batch size张图片,采用两个output取平均来计算accuracy。
output = (F.softmax(output1, dim=-1)+F.softmax(output2, dim=-1))/2
# measure accuracy and record loss
prec1, prec5 = accuracy(output, target, topk=(1, 5))
2.5 实验
Start epoch 68 ==========,lr=0.000000
Epoch: [68][0/375] Time 2.643 (2.643) Data 1.803 (1.803) Loss 0.1370 (0.1370) Prec@1 100.000 (100.000) Prec@5 100.000 (100.000)
loss1,loss2,loss3,feature_center_loss 0.033498723059892654 0.014470446854829788 0.025759508833289146 0.11245858669281006
Epoch: [68][100/375] Time 0.898 (0.918) Data 0.000 (0.020) Loss 0.1923 (0.2394) Prec@1 100.000 (100.000) Prec@5 100.000 (100.000)
loss1,loss2,loss3,feature_center_loss 0.03685980290174484 0.006798317655920982 0.08380372077226639 0.14980196952819824
Epoch: [68][200/375] Time 0.896 (0.909) Data 0.000 (0.012) Loss 0.1688 (0.2400) Prec@1 100.000 (100.000) Prec@5 100.000 (100.000)
loss1,loss2,loss3,feature_center_loss 0.05070505663752556 0.029433125630021095 0.04474890977144241 0.1271676868200302
Epoch: [68][300/375] Time 0.898 (0.906) Data 0.000 (0.009) Loss 0.2484 (0.2416) Prec@1 100.000 (100.000) Prec@5 100.000 (100.000)
loss1,loss2,loss3,feature_center_loss 0.06296264380216599 0.013908100314438343 0.24796684086322784 0.14007914066314697
Test: [0/363] Time 1.313 (1.313) Loss 0.8828 (0.8828) Prec@1 93.750 (93.750) Prec@5 100.000 (100.000)
Test: [100/363] Time 0.359 (0.373) Loss 1.0795 (1.0499) Prec@1 87.500 (87.871) Prec@5 100.000 (97.834)
Test: [200/363] Time 0.357 (0.367) Loss 0.5304 (1.0939) Prec@1 87.500 (87.034) Prec@5 100.000 (97.668)
Test: [300/363] Time 0.361 (0.365) Loss 1.0650 (1.1211) Prec@1 81.250 (86.981) Prec@5 93.750 (97.550)
* Prec@1 87.832 Prec@5 97.618
Start epoch 69 ==========,lr=0.000000
Epoch: [69][0/375] Time 1.991 (1.991) Data 1.223 (1.223) Loss 0.3047 (0.3047) Prec@1 100.000 (100.000) Prec@5 100.000 (100.000)
loss1,loss2,loss3,feature_center_loss 0.05962998792529106 0.0023914300836622715 0.31575942039489746 0.17878621816635132
Epoch: [69][100/375] Time 0.897 (0.911) Data 0.000 (0.015) Loss 0.2376 (0.2468) Prec@1 100.000 (100.000) Prec@5 100.000 (100.000)
loss1,loss2,loss3,feature_center_loss 0.09725010395050049 0.06658463180065155 0.056900348514318466 0.1640474945306778
Epoch: [69][200/375] Time 0.896 (0.906) Data 0.000 (0.009) Loss 0.2108 (0.2511) Prec@1 100.000 (99.969) Prec@5 100.000 (100.000)
loss1,loss2,loss3,feature_center_loss 0.06277890503406525 0.03270210325717926 0.042565133422613144 0.1647663414478302
Epoch: [69][300/375] Time 0.897 (0.904) Data 0.000 (0.007) Loss 0.4613 (0.2520) Prec@1 100.000 (99.979) Prec@5 100.000 (100.000)
loss1,loss2,loss3,feature_center_loss 0.06680921465158463 0.015883391723036766 0.9120850563049316 0.12972146272659302
Test: [0/363] Time 1.381 (1.381) Loss 1.1078 (1.1078) Prec@1 93.750 (93.750) Prec@5 100.000 (100.000)
Test: [100/363] Time 0.363 (0.374) Loss 1.0061 (1.0701) Prec@1 87.500 (88.119) Prec@5 100.000 (97.649)
Test: [200/363] Time 0.359 (0.368) Loss 0.4331 (1.1050) Prec@1 93.750 (87.469) Prec@5 100.000 (97.544)
Test: [300/363] Time 0.360 (0.367) Loss 1.1410 (1.1201) Prec@1 81.250 (87.396) Prec@5 93.750 (97.529)
* Prec@1 88.091 Prec@5 97.566
Start epoch 70 ==========,lr=0.000000
Epoch: [70][0/375] Time 2.134 (2.134) Data 1.318 (1.318) Loss 0.3534 (0.3534) Prec@1 100.000 (100.000) Prec@5 100.000 (100.000)
loss1,loss2,loss3,feature_center_loss 0.050787247717380524 0.04547658935189247 0.5441681742668152 0.1399531215429306
Test: [0/363] Prec@1 100.000 (100.000) Prec@5 100.000 (100.000)
Test: [100/363] Prec@1 93.750 (89.047) Prec@5 100.000 (98.144)
Test: [200/363] Prec@1 93.750 (88.588) Prec@5 100.000 (97.886)
Test: [300/363] Prec@1 81.250 (88.746) Prec@5 93.750 (97.799)
* Prec@1 89.472 Prec@5 97.825
跑了70个epoch,在测试集上top-1准确率达到了89.4%,复现了论文的准确率。
参考资料
Hu, Tao, Honggang Qi, Qingming Huang和Yan Lu. 《See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification》. arXiv, 2019年3月23日. https://doi.org/10.48550/arXiv.1901.09891.
https://github.com/wvinzh/WS_DAN_PyTorch
Szegedy, Christian, Vincent Vanhoucke, Sergey Ioffe, Jon Shlens和Zbigniew Wojna. 《Rethinking the Inception Architecture for Computer Vision》. 收入 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2818–26. Las Vegas, NV, USA: IEEE, 2016. https://doi.org/10.1109/CVPR.2016.308.