近期在学习论文《See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification》（WS-DAN）的 PyTorch 版实现（代码地址）。
本文对其中的 eval.py 进行解读。由于本人是新手，理解有误之处，敬请批评指正。
"""EVALUATION
Created: Nov 22,2019 - Yuchong Gu
Revised: Dec 03,2019 - Yuchong Gu
"""
import os
import logging
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import config
from models import WSDAN
from datasets import get_trainval_datasets
from utils import TopKAccuracyMetric, batch_augment
# GPU settings
# CUDA_VISIBLE_DEVICES is only honoured if it is set before the CUDA runtime
# initializes, and torch.cuda.is_available() may trigger that initialization.
# So set the variable first, then query CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = str(config.GPU)
if not torch.cuda.is_available():
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    raise RuntimeError('CUDA is required for evaluation but no GPU is available')
device = torch.device("cuda")
# Inputs are resized to config.image_size, so shapes stay (mostly) constant
# and letting cuDNN benchmark kernels pays off.
torch.backends.cudnn.benchmark = True

# visualize
visualize = config.visualize
savepath = config.eval_savepath
if visualize:
    os.makedirs(savepath, exist_ok=True)
# Shared converter from (C, H, W) tensors to PIL images, used when saving
# visualizations.
ToPILImage = transforms.ToPILImage()
# Per-channel mean/std reshaped to (1, 3, 1, 1) so they broadcast over an
# NCHW batch when undoing input normalization. These are the common ImageNet
# statistics; presumably they match the Normalize in `datasets` -- confirm.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
def generate_heatmap(attention_maps):
    """Turn the first attention channel into a 3-channel pseudo-colour map.

    With ``a = attention_maps[:, 0, ...]`` the output channels are:

    * red   = ``a``
    * green = ``a`` where ``a < 0.5``, otherwise ``1 - a`` (a tent peaking at 0.5)
    * blue  = ``1 - a``

    Only channel 0 of ``attention_maps`` is read. The caller in this file
    normalizes the maps by their maximum first; values are not clamped here.

    Args:
        attention_maps: tensor indexable as ``[:, 0, ...]``.

    Returns:
        Tensor with (red, green, blue) stacked along dim 1.
    """
    a = attention_maps[:, 0, ...]
    below_half = (a < 0.5).float()
    at_or_above_half = (a >= 0.5).float()

    red = a
    green = a * below_half + (1. - a) * at_or_above_half
    blue = 1. - a
    return torch.stack([red, green, blue], dim=1)
def _save_visualizations(X, attention_maps, start_index):
    """Write raw, attention-weighted and heat-map overlay images for one batch.

    Args:
        X: normalized input batch as fed to the network.
        attention_maps: attention output of the network for ``X``.
            NOTE(review): assumed (B, C, h, w) with the map of interest in
            channel 0, the only channel ``generate_heatmap`` reads -- confirm
            against WSDAN.forward.
        start_index: global index of the batch's first sample; used to give
            every saved file a unique name.
    """
    # Resize attention to the input resolution. F.upsample_bilinear is
    # deprecated; interpolate(mode='bilinear', align_corners=True) is its
    # documented exact equivalent.
    attention_maps = F.interpolate(attention_maps, size=(X.size(2), X.size(3)),
                                   mode='bilinear', align_corners=True).cpu()
    # Normalize by the batch-wide maximum; sqrt brightens weak responses.
    max_value = attention_maps.max().item()
    if max_value == 0:
        max_value = 1.0  # all-zero maps: avoid 0/0 -> NaN, result stays 0
    attention_maps = torch.sqrt(attention_maps / max_value)

    heat_attention_maps = generate_heatmap(attention_maps)

    # Undo input normalization. Clamp because ToPILImage scales float tensors
    # by 255 before the uint8 cast, so out-of-range values would wrap.
    raw_image = (X.cpu() * STD + MEAN).clamp(0., 1.)
    # 50/50 blend of the heat map over the image.
    heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
    # Image weighted by attention: attended regions stay bright.
    raw_attention_image = raw_image * attention_maps

    for batch_idx in range(X.size(0)):
        index = start_index + batch_idx
        ToPILImage(raw_image[batch_idx]).save(
            os.path.join(savepath, '%03d_raw.jpg' % index))
        ToPILImage(raw_attention_image[batch_idx]).save(
            os.path.join(savepath, '%03d_raw_atten.jpg' % index))
        ToPILImage(heat_attention_image[batch_idx]).save(
            os.path.join(savepath, '%03d_heat_atten.jpg' % index))


def main():
    """Evaluate a trained WS-DAN checkpoint on the test split.

    Two accuracies (top-1, top-5) are tracked:

    * raw    -- prediction on the full image;
    * refine -- average of the raw prediction and the prediction on an
      attention-guided crop of the image.

    When ``config.visualize`` is set, per-sample images are also written to
    ``config.eval_savepath``. Returns early, after logging, if
    ``config.eval_ckpt`` is not defined.
    """
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # A missing attribute means no checkpoint was configured. Catch only that:
    # a bare `except:` would also swallow KeyboardInterrupt and real errors.
    try:
        ckpt = config.eval_ckpt
    except AttributeError:
        logging.error('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes, M=config.num_attentions, net=config.net)
    # Load onto the CPU first so a checkpoint written from a different GPU id
    # still loads; the model is moved to `device` below.
    checkpoint = torch.load(ckpt, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 5))  # full-image prediction
    ref_accuracy = TopKAccuracyMetric(topk=(1, 5))  # full image + attention crop
    raw_accuracy.reset()
    ref_accuracy.reset()
    batch_info = None

    # Switch layers such as BatchNorm/Dropout to inference behaviour.
    net.eval()
    with torch.no_grad(), tqdm(total=len(test_loader), unit=' batches') as pbar:
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # Coarse prediction on the whole image, plus the attention maps.
            y_pred_raw, _, attention_maps = net(X)

            # Key step: crop the region highlighted by the attention maps,
            # predict again on that zoomed-in view, and average the coarse and
            # fine predictions into the refined result.
            crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_image)
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            if visualize:
                _save_visualizations(X, attention_maps, i * config.batch_size)

            # Top K. NOTE(review): the `epoch_` names and the single reset()
            # suggest these metrics accumulate over all batches seen so far --
            # confirm in utils.TopKAccuracyMetric.
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0], epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

    # Record the final accuracies in the log, not only on the progress bar.
    if batch_info is not None:
        logging.info(batch_info)
if __name__ == "__main__":
    # Start the evaluation only when this file is executed directly; importing
    # the module does not call main().
    main()