I have recently been studying the PyTorch implementation of "See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification" (WS-DAN).
This post walks through utils.py. I am still a beginner, so please point out anything I have misunderstood. The file collects the project's helper code: the evaluation metrics, the model checkpoint callback, the attention crop / attention drop augmentation, and the dataset transforms.
"""Utils
Created: Nov 11,2019 - Yuchong Gu
Revised: Dec 03,2019 - Yuchong Gu
"""
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
##############################################
# Center Loss for Attention Regularization
##############################################
class CenterLoss(nn.Module):
    def __init__(self):
        super(CenterLoss, self).__init__()
        self.l2_loss = nn.MSELoss(reduction='sum')

    def forward(self, outputs, targets):
        # summed squared distance to the class centers, averaged over the batch
        return self.l2_loss(outputs, targets) / outputs.size(0)
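
# ----------------------------------------------------------------------------
# Usage sketch (my addition, not part of the original utils.py): in WS-DAN the
# center loss pulls each image's pooled attention features toward a running
# per-class feature center. The class count and feature length below are
# illustrative assumptions, not values from the repo.
# ----------------------------------------------------------------------------
def _demo_center_loss():
    num_classes, feature_len = 200, 32 * 768              # assumed: M attention maps * feature dim
    feature_center = torch.zeros(num_classes, feature_len)
    feature_matrix = torch.randn(4, feature_len)           # pooled features for a batch of 4
    labels = torch.randint(0, num_classes, (4,))
    center_loss = CenterLoss()
    # squared L2 distance to each sample's class center, averaged over the batch
    return center_loss(feature_matrix, feature_center[labels])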
##################################
# Metric
##################################
class Metric(object):
    pass


class AverageMeter(Metric):
    def __init__(self, name='loss'):
        self.name = name
        self.reset()

    def reset(self):
        self.scores = 0.
        self.total_num = 0.

    def __call__(self, batch_score, sample_num=1):
        self.scores += batch_score
        self.total_num += sample_num
        return self.scores / self.total_num
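
# Usage sketch (my addition): AverageMeter keeps a running average across batches;
# reset() is typically called at the start of each epoch. The numbers are made up.
def _demo_average_meter():
    loss_meter = AverageMeter(name='loss')
    loss_meter.reset()
    loss_meter(batch_score=0.8 * 16, sample_num=16)        # batch 1: summed loss of 16 samples
    avg = loss_meter(batch_score=0.6 * 16, sample_num=16)  # batch 2: running average is now 0.7
    return avg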
class TopKAccuracyMetric(Metric):
    def __init__(self, topk=(1,)):
        self.name = 'topk_accuracy'
        self.topk = topk
        self.maxk = max(topk)
        self.reset()

    def reset(self):
        self.corrects = np.zeros(len(self.topk))
        self.num_samples = 0.

    def __call__(self, output, target):
        """Computes the precision@k for the specified values of k"""
        self.num_samples += target.size(0)
        _, pred = output.topk(self.maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        for i, k in enumerate(self.topk):
            # note: on newer PyTorch versions this .view(-1) may need to be .reshape(-1)
            correct_k = correct[:k].view(-1).float().sum(0)
            self.corrects[i] += correct_k.item()

        return self.corrects * 100. / self.num_samples
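
# Usage sketch (my addition): the metric accumulates correct counts over an epoch
# and returns the cumulative top-k accuracies in percent. Shapes are assumptions
# (batch of 8, 200 classes as in CUB-200).
def _demo_topk_accuracy():
    acc_metric = TopKAccuracyMetric(topk=(1, 5))
    acc_metric.reset()
    logits = torch.randn(8, 200)
    labels = torch.randint(0, 200, (8,))
    top1_top5 = acc_metric(logits, labels)                 # np.ndarray: [top-1 %, top-5 %]
    return top1_top5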
##################################
# Callback
##################################
class Callback(object):
    def __init__(self):
        pass

    def on_epoch_begin(self):
        pass

    def on_epoch_end(self, *args):
        pass
class ModelCheckpoint(Callback):
    def __init__(self, savepath, monitor='val_topk_accuracy', mode='max'):
        self.savepath = savepath
        self.monitor = monitor
        self.mode = mode
        self.reset()
        super(ModelCheckpoint, self).__init__()

    def reset(self):
        if self.mode == 'max':
            self.best_score = float('-inf')
        else:
            self.best_score = float('inf')

    def set_best_score(self, score):
        if isinstance(score, np.ndarray):
            self.best_score = score[0]
        else:
            self.best_score = score

    def on_epoch_begin(self):
        pass

    def on_epoch_end(self, logs, net, **kwargs):
        current_score = logs[self.monitor]
        if isinstance(current_score, np.ndarray):
            current_score = current_score[0]

        if (self.mode == 'max' and current_score > self.best_score) or \
                (self.mode == 'min' and current_score < self.best_score):
            self.best_score = current_score

            if isinstance(net, torch.nn.DataParallel):
                state_dict = net.module.state_dict()
            else:
                state_dict = net.state_dict()

            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cpu()

            if 'feature_center' in kwargs:
                feature_center = kwargs['feature_center']
                feature_center = feature_center.cpu()
                torch.save({
                    'logs': logs,
                    'state_dict': state_dict,
                    'feature_center': feature_center}, self.savepath)
            else:
                torch.save({
                    'logs': logs,
                    'state_dict': state_dict}, self.savepath)
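
# Usage sketch (my addition): how the callback is typically driven from a training
# loop. The save path, the logs dict and the feature_center tensor are assumptions.
def _demo_model_checkpoint(net, feature_center):
    callback = ModelCheckpoint(savepath='wsdan_demo.pth',
                               monitor='val_topk_accuracy', mode='max')
    callback.on_epoch_begin()
    logs = {'val_topk_accuracy': np.array([87.5, 96.2])}   # e.g. [top-1, top-5]
    # saves net (and feature_center) only when the monitored score beats the best so far
    callback.on_epoch_end(logs, net, feature_center=feature_center)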
#########################################################################################################################
# augment function: attention crop and attention drop (the key code of this file)
########################################################################################################################
def batch_augment(images, attention_map, mode='crop', theta=0.5, padding_ratio=0.1):
    batches, _, imgH, imgW = images.size()  # B, C, H, W

    if mode == 'crop':  # attention crop
        crop_images = []  # holds the cropped (and re-resized) images
        for batch_index in range(batches):  # process the batch_index-th image
            # the caller passes attention_map[:, :1, :, :], i.e. channel 0 is used for cropping
            atten_map = attention_map[batch_index:batch_index + 1]  # attention map of the batch_index-th image
            if isinstance(theta, tuple):
                theta_c = random.uniform(*theta) * atten_map.max()  # threshold theta_c relative to the map's maximum
            else:
                theta_c = theta * atten_map.max()

            # upsample to the image size; positions above the threshold become True, shape [1, 1, H, W]
            # (F.upsample_bilinear is deprecated in newer PyTorch; F.interpolate(..., mode='bilinear') replaces it)
            crop_mask = F.upsample_bilinear(atten_map, size=(imgH, imgW)) >= theta_c
            # indices of the True positions; [0, 0, ...] selects the single image and its single channel
            # (of the two attention maps the network produces, channel 0 is the one used for crop)
            nonzero_indices = torch.nonzero(crop_mask[0, 0, ...])
            # pad the bounding box of the attended region by padding_ratio so the crop fully covers it,
            # then clamp to the image borders
            height_min = max(int(nonzero_indices[:, 0].min().item() - padding_ratio * imgH), 0)
            height_max = min(int(nonzero_indices[:, 0].max().item() + padding_ratio * imgH), imgH)
            width_min = max(int(nonzero_indices[:, 1].min().item() - padding_ratio * imgW), 0)
            width_max = min(int(nonzero_indices[:, 1].max().item() + padding_ratio * imgW), imgW)

            # crop the attended region from the original image and resize it back to (imgH, imgW)
            crop_images.append(
                F.upsample_bilinear(images[batch_index:batch_index + 1, :, height_min:height_max, width_min:width_max],
                                    size=(imgH, imgW)))
        crop_images = torch.cat(crop_images, dim=0)
        return crop_images

    elif mode == 'drop':  # attention drop; the caller passes attention_map[:, 1:, :, :]
        drop_masks = []
        for batch_index in range(batches):
            atten_map = attention_map[batch_index:batch_index + 1]
            if isinstance(theta, tuple):
                theta_d = random.uniform(*theta) * atten_map.max()
            else:
                theta_d = theta * atten_map.max()
            # keep only the positions below the threshold, so the highly attended part gets dropped
            drop_masks.append(F.upsample_bilinear(atten_map, size=(imgH, imgW)) < theta_d)
        drop_masks = torch.cat(drop_masks, dim=0)
        drop_images = images * drop_masks.float()  # zero out the high-attention regions of the images
        return drop_images

    else:
        raise ValueError('Expected mode in [\'crop\', \'drop\'], but received unsupported augmentation method %s' % mode)
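
# Usage sketch (my addition): how WS-DAN calls batch_augment during training. In the
# repo the attention maps come from the model's forward pass; here they are random
# placeholders with assumed shapes (batch of 4, two selected maps of size 26x26).
def _demo_batch_augment():
    images = torch.randn(4, 3, 448, 448)                   # assumed input resolution
    attention_map = torch.rand(4, 2, 26, 26)                # channel 0 for crop, channel 1 for drop
    crop_images = batch_augment(images, attention_map[:, :1, :, :],
                                mode='crop', theta=(0.4, 0.6), padding_ratio=0.1)
    drop_images = batch_augment(images, attention_map[:, 1:, :, :],
                                mode='drop', theta=(0.2, 0.5))
    return crop_images, drop_images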
##################################
# transform in dataset
##################################
def get_transform(resize, phase='train'):
    if phase == 'train':
        return transforms.Compose([
            transforms.Resize(size=(int(resize[0] / 0.875), int(resize[1] / 0.875))),  # resize so the target size is ~87.5% of the resized image
            transforms.RandomCrop(resize),                              # 1. random crop back to the target size
            transforms.RandomHorizontalFlip(0.5),                       # 2. horizontal flip with probability 0.5
            transforms.ColorJitter(brightness=0.126, saturation=0.5),   # 3. jitter brightness and saturation
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize with ImageNet statistics
        ])
    else:
        return transforms.Compose([
            transforms.Resize(size=(int(resize[0] / 0.875), int(resize[1] / 0.875))),
            transforms.CenterCrop(resize),  # center crop for validation / test
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
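
# Usage sketch (my addition): plugging the transforms into a torchvision dataset.
# The dataset layout (ImageFolder) and the (448, 448) input size are assumptions.
def _demo_get_transform(data_root):
    from torchvision.datasets import ImageFolder
    train_set = ImageFolder(data_root + '/train', transform=get_transform((448, 448), phase='train'))
    val_set = ImageFolder(data_root + '/val', transform=get_transform((448, 448), phase='val'))
    return train_set, val_set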