目录
背景介绍
本文架构
具体算法实现细节
背景介绍
原文链接:WSDAN论文链接
以往的数据增强办法:随机图像裁剪;图像旋转;色彩失真;
细粒度识别存在的问题:训练数据不足;类内方差过大(不同的姿势);类间方差过小(外貌只有细微差别);如果需要人为标识discriminative的部分,则会有额外的cost
本文架构
具体算法实现细节
1. Inception实现细节:Inception的大体架构及代码
2. BAP实现细节
class BAP(nn.Module):
    """Bilinear Attention Pooling (Sec. 3.1.2 of the WS-DAN paper).

    Pools the backbone feature maps with every attention map, applies
    signed square-root normalisation, and L2-normalises the flattened
    part-feature matrix per sample.
    """

    def __init__(self, **kwargs):
        super(BAP, self).__init__()

    def forward(self, feature_maps, attention_maps):
        """Return (raw_features, pooling_features), both (B, M*C).

        feature_maps:   (B, C, h, w), e.g. 12 x 768 x 26 x 26
        attention_maps: (B, M, h, w), e.g. 12 x num_parts x 26 x 26
        """
        batch = feature_maps.size(0)
        num_pixels = float(attention_maps.size(2) * attention_maps.size(3))

        # Bilinear pooling: for each (part m, channel n) pair, sum the
        # element-wise product over the spatial grid -> (B, M, C),
        # then divide by the pixel count (spatial average).
        pooled = torch.einsum('imjk,injk->imn', (attention_maps, feature_maps))
        pooled = pooled / num_pixels

        # Signed square root; the epsilon keeps the gradient finite at 0.
        pooled = torch.mul(torch.sign(pooled), torch.sqrt(torch.abs(pooled) + 1e-12))

        # Flatten parts x channels and L2-normalise each sample's vector.
        flat = pooled.reshape(batch, -1)
        raw_features = torch.nn.functional.normalize(flat, dim=-1)

        # Scaled copy of the normalised features for the classifier head.
        pooling_features = raw_features * 100
        return raw_features, pooling_features
3. Attention crop
def attention_crop(attention_maps, input_image):
    """Attention cropping (paper Sec. 3.2.2): zoom into one attended part.

    For each image, one attention map is sampled with probability
    proportional to its energy (Sec. 3.2.1), thresholded at a random
    fraction of its peak, and the bounding box of the active region
    (with 10% padding) is cropped and resized back to the input size.

    Args:
        attention_maps: (B, num_parts, h, w) attention maps.
        input_image:    (B, C, H, W) input batch.

    Returns:
        (B, C, H, W) tensor of cropped-and-resized images.

    Fixes over the original:
      * dims were unpacked as ``B,N,W,H`` so the height padding was taken
        from the width axis for non-square inputs;
      * ``np.where`` on a (possibly CUDA) tensor fails — use
        ``torch.nonzero`` like the sibling ``mask2bbox``;
      * the bbox end index is now inclusive (+1 before slicing).
    """
    batch_size, _, img_h, img_w = input_image.shape
    num_parts = attention_maps.shape[1]

    # Upsample attention to image resolution; detach so the crop branch
    # does not back-propagate into the attention maps.
    attention_maps = torch.nn.functional.interpolate(
        attention_maps.detach(), size=(img_h, img_w), mode='bilinear')

    # Part sampling weights: sqrt of the mean activation per part,
    # normalised into a probability distribution per sample.
    part_weights = F.avg_pool2d(attention_maps, (img_h, img_w)).reshape(batch_size, -1)
    part_weights = torch.add(torch.sqrt(part_weights), 1e-12)
    part_weights = torch.div(part_weights, torch.sum(part_weights, dim=1).unsqueeze(1)).cpu()
    part_weights = part_weights.numpy()

    ret_imgs = []
    for i in range(batch_size):
        part_weight = part_weights[i]
        # Sample one part index with probability proportional to its weight.
        selected_index = np.random.choice(
            np.arange(0, num_parts), 1, p=part_weight)[0]
        mask = attention_maps[i, selected_index]

        # Random threshold as a fraction of the peak activation.
        threshold = random.uniform(0.4, 0.6)
        # torch.nonzero works on CPU and CUDA alike; rows are (y, x)
        # coordinates of pixels above the threshold.
        itemindex = torch.nonzero(mask >= mask.max() * threshold)

        # Bounding box of the active region with 10% padding, clamped at
        # the image border; +1 because Python slicing excludes the end.
        padding_h = int(0.1 * img_h)
        padding_w = int(0.1 * img_w)
        height_min = max(0, int(itemindex[:, 0].min()) - padding_h)
        height_max = int(itemindex[:, 0].max()) + padding_h + 1
        width_min = max(0, int(itemindex[:, 1].min()) - padding_w)
        width_max = int(itemindex[:, 1].max()) + padding_w + 1

        out_img = input_image[i][:, height_min:height_max, width_min:width_max].unsqueeze(0)
        # Resize the crop back to the input resolution.
        out_img = torch.nn.functional.interpolate(
            out_img, size=(img_h, img_w), mode='bilinear', align_corners=True)
        ret_imgs.append(out_img.squeeze(0))
    return torch.stack(ret_imgs)
4. Attention Dropping
def attention_drop(attention_maps, input_image):
    """Attention dropping: erase one randomly chosen part region per image.

    A part is sampled with probability proportional to its attention
    energy, the map is thresholded at a random fraction of its peak, and
    the activated area is zeroed out of the input — forcing the network
    to attend elsewhere. Returns a tensor shaped like ``input_image``.
    """
    batch_size, _, img_h, img_w = input_image.shape
    num_parts = attention_maps.shape[1]

    # Upsample the (detached) attention maps to the input resolution.
    attention_maps = torch.nn.functional.interpolate(
        attention_maps.detach(), size=(img_h, img_w), mode='bilinear')

    # Per-part sampling weights: sqrt of the mean activation, normalised
    # into a probability distribution per sample.
    part_weights = F.avg_pool2d(attention_maps, (img_h, img_w)).reshape(batch_size, -1)
    part_weights = torch.add(torch.sqrt(part_weights), 1e-12)
    part_weights = torch.div(
        part_weights, torch.sum(part_weights, dim=1).unsqueeze(1)).cpu().numpy()

    masks = []
    for i in range(batch_size):
        one_map = attention_maps[i].detach()
        # Pick one part index with probability proportional to its weight.
        chosen = np.random.choice(
            np.arange(0, num_parts), 1, p=part_weights[i])[0]
        selected = one_map[chosen:chosen + 1, :, :]  # keep the channel dim
        # Keep only pixels BELOW a random fraction of the peak activation,
        # i.e. drop the attended region.
        threshold = random.uniform(0.2, 0.5)
        masks.append((selected < threshold * selected.max()).float())

    # (B, 1, H, W) masks broadcast over the image channels.
    return input_image * torch.stack(masks)
5. test的过程
def mask2bbox(attention_maps, input_image):
    """Test-time crop: bound the averaged attention and zoom into it.

    Unlike the training-time crop this uses the MEAN over all attention
    maps and a fixed 10% threshold, so it is deterministic.

    Args:
        attention_maps: (B, num_parts, h, w) attention maps.
        input_image:    (B, C, H, W) input batch.

    Returns:
        (B, C, H, W) tensor of cropped-and-resized images.

    Fixes over the original: ``interpolate`` takes ``size=(height, width)``
    but was passed ``(W, H)`` — only accidentally correct for square
    inputs; and the bbox end index is now inclusive (+1 before slicing).
    """
    input_tensor = input_image
    B, C, H, W = input_tensor.shape
    batch_size = attention_maps.shape[0]
    # Upsample attention to image resolution; size order is (H, W).
    attention_maps = torch.nn.functional.interpolate(
        attention_maps, size=(H, W), mode='bilinear')
    ret_imgs = []
    for i in range(batch_size):
        # Average over all parts -> one saliency map per image (the only
        # difference from the training-time crop).
        mask = attention_maps[i].mean(dim=0)
        # Fixed threshold at 10% of the peak activation.
        threshold = 0.1
        min_activate = threshold * mask.max()
        itemindex = torch.nonzero(mask >= min_activate)
        # 5% padding around the bbox, clamped at the border; +1 so the
        # max index is included (Python slices exclude the end).
        padding_h = int(0.05 * H)
        padding_w = int(0.05 * W)
        height_min = max(0, int(itemindex[:, 0].min()) - padding_h)
        height_max = int(itemindex[:, 0].max()) + padding_h + 1
        width_min = max(0, int(itemindex[:, 1].min()) - padding_w)
        width_max = int(itemindex[:, 1].max()) + padding_w + 1
        out_img = input_tensor[i][:, height_min:height_max, width_min:width_max].unsqueeze(0)
        out_img = torch.nn.functional.interpolate(
            out_img, size=(H, W), mode='bilinear', align_corners=True)
        ret_imgs.append(out_img.squeeze(0))
    return torch.stack(ret_imgs)
# For easier side-by-side reading, the train and test code are grouped below.
class Engine():
    """Training / validation / test loops for the WS-DAN model.

    All collaborators (model, criterion, optimizer, data loaders and the
    per-class feature centers) are read out of the ``state`` dict, so the
    engine itself holds no state of its own.
    """

    def __init__(self,):
        pass

    def train(self,state,epoch):
        """Run one training epoch.

        Args:
            state: dict with 'config', 'model', 'criterion', 'optimizer',
                'train_loader' and 'center' (per-class feature centers).
            epoch: current epoch index, used only for logging.

        Returns:
            (top-1 accuracy average, loss average) over the epoch.
        """
        batch_time = AverageMeter()  # running value/average meter - http://codingdict.com/sources/py/utils/5219.html
        data_time = AverageMeter()  # NOTE(review): presumably exposes .val/.avg via .update() - confirm AverageMeter's contract
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        config = state['config']
        print_freq = config.print_freq  # log every `print_freq` batches
        model = state['model']
        criterion = state['criterion']
        optimizer = state['optimizer']
        train_loader = state['train_loader']
        model.train()
        end = time.time()
        for i, (img, label) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            target = label.cuda()
            input = img.cuda()
            # compute output
            attention_maps, raw_features, output1 = model(input)  # model returns (attention maps, pooled features, logits)
            features = raw_features.reshape(raw_features.shape[0], -1)
            feature_center_loss, center_diff = calculate_pooling_center_loss(
                features, state['center'], target, alfa=config.alpha)  # NOTE(review): center loss pulling part features toward class centers - see helper for the exact formula
            # update model.centers
            state['center'][target] += center_diff
            # compute refined loss
            # img_drop = attention_drop(attention_maps,input)
            # img_crop = attention_crop(attention_maps, input)
            img_crop, img_drop = attention_crop_drop(attention_maps, input)
            _, _, output2 = model(img_drop)
            _, _, output3 = model(img_crop)
            loss1 = criterion(output1, target)
            loss2 = criterion(output2, target)
            loss3 = criterion(output3, target)
            # mean of raw/drop/crop losses plus the center loss
            loss = (loss1+loss2+loss3)/3 + feature_center_loss
            # measure accuracy and record loss
            prec1, prec5 = accuracy(output1, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()  # back-propagate the gradients
            optimizer.step()  # apply the optimizer update
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            # periodic progress logging
            if i % print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch, i, len(train_loader), batch_time=batch_time,
                          data_time=data_time, loss=losses, top1=top1, top5=top5))
                print("loss1,loss2,loss3,feature_center_loss", loss1.item(), loss2.item(), loss3.item(),
                      feature_center_loss.item())
        return top1.avg, losses.avg

    def validate(self,state):
        """Evaluate on the validation loader using the training-style losses.

        Args:
            state: dict with 'config', 'model', 'criterion', 'val_loader'
                and 'center'.

        Returns:
            (top-1 accuracy average, loss average).
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        config = state['config']
        print_freq = config.print_freq
        model = state['model']
        val_loader = state['val_loader']
        criterion = state['criterion']
        # switch to evaluate mode
        model.eval()
        with torch.no_grad():
            end = time.time()
            for i, (input, target) in enumerate(val_loader):
                target = target.cuda()
                input = input.cuda()
                # forward
                attention_maps, raw_features, output1 = model(input)
                features = raw_features.reshape(raw_features.shape[0], -1)
                feature_center_loss, _ = calculate_pooling_center_loss(
                    features, state['center'], target, alfa=config.alpha)
                img_crop, img_drop = attention_crop_drop(attention_maps, input)
                # img_drop = attention_drop(attention_maps,input)
                # img_crop = attention_crop(attention_maps,input)
                _, _, output2 = model(img_drop)
                _, _, output3 = model(img_crop)
                loss1 = criterion(output1, target)
                loss2 = criterion(output2, target)
                loss3 = criterion(output3, target)
                # loss = loss1 + feature_center_loss
                loss = (loss1+loss2+loss3)/3+feature_center_loss
                # measure accuracy and record loss
                prec1, prec5 = accuracy(output1, target, topk=(1, 5))
                losses.update(loss.item(), input.size(0))
                top1.update(prec1[0], input.size(0))
                top5.update(prec5[0], input.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if i % print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                              i, len(val_loader), batch_time=batch_time, loss=losses,
                              top1=top1, top5=top5))
        print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
        return top1.avg, losses.avg

    def test(self,val_loader, model, criterion):
        """Test-time evaluation: average the softmax of the raw prediction
        and of the prediction on the attention-guided crop.

        Returns:
            (top-1 accuracy average, top-5 accuracy average).
        """
        top1 = AverageMeter()
        top5 = AverageMeter()
        print_freq = 100
        # switch to evaluate mode
        model.eval()
        with torch.no_grad():
            for i, (input, target) in enumerate(val_loader):
                target = target.cuda()
                input = input.cuda()
                # forward
                attention_maps, _, output1 = model(input)  # p1: prediction on the raw image
                refined_input = mask2bbox(attention_maps, input)  # p2: crop from the averaged attention (same idea as the training-time crop)
                _, _, output2 = model(refined_input)
                output = (F.softmax(output1, dim=-1)+F.softmax(output2, dim=-1))/2
                # measure accuracy and record loss
                prec1, prec5 = accuracy(output, target, topk=(1, 5))
                top1.update(prec1[0], input.size(0))
                top5.update(prec5[0], input.size(0))
                if i % print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                              i, len(val_loader),
                              top1=top1, top5=top5))
        print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
        return top1.avg, top5.avg