Xiao Zhao's Paper Reading Series -- See Better Before Looking Closer: WSDAN

Contents

Background

Architecture

Algorithm implementation details


Background

Original paper: WSDAN paper link

Data augmentation methods used in previous work: random image cropping, image rotation, and color distortion.
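As a reference point, here is a minimal sketch of what these baseline augmentations look like with torchvision.transforms; the specific parameter values are illustrative assumptions, not taken from the paper.

# Baseline augmentations mentioned above (illustrative hyper-parameters)
import torchvision.transforms as T

baseline_augment = T.Compose([
    T.RandomResizedCrop(448),      # random image cropping
    T.RandomRotation(degrees=15),  # image rotation
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),  # color distortion
    T.ToTensor(),
])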

Problems in fine-grained recognition: insufficient training data; large intra-class variance (e.g. the same class in different poses); small inter-class variance (different classes differ only in subtle appearance details); and manually annotating the discriminative parts incurs extra labeling cost.

Architecture

Algorithm implementation details

1. Inception implementation details: the overall Inception architecture and code
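Since the backbone itself is covered in that separate post, I only sketch here how the pieces below fit together: a 1x1 convolution on the Inception feature maps (e.g. the 768-channel output around Mixed_6e) produces num_parts attention maps, which are fed to the BAP module defined in the next section. The class name WSDANSketch, the ReLU on the attention maps, and the default sizes are my assumptions, not the exact repo code.

import torch.nn as nn

class WSDANSketch(nn.Module):
    """Sketch of the overall model: backbone feature maps -> 1x1-conv attention maps
    -> BAP -> linear classifier. `backbone` is assumed to be any CNN (e.g. Inception v3
    truncated at Mixed_6e) that returns feature maps of shape (B, in_channels, H, W)."""
    def __init__(self, backbone, in_channels=768, num_parts=32, num_classes=200):
        super().__init__()
        self.backbone = backbone
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, num_parts, kernel_size=1),
            nn.ReLU(inplace=True),  # keeps the attention maps non-negative (my assumption)
        )
        self.bap = BAP()  # defined in the next section
        self.fc = nn.Linear(num_parts * in_channels, num_classes)

    def forward(self, x):
        feature_maps = self.backbone(x)             # (B, C, H, W)
        attention_maps = self.attention(feature_maps)  # (B, num_parts, H, W)
        raw_features, pooling_features = self.bap(feature_maps, attention_maps)
        logits = self.fc(pooling_features)
        # same three outputs that the training loop in Section 5 unpacks
        return attention_maps, raw_features, logits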

2. BAP implementation details

# Imports shared by all code snippets below
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class BAP(nn.Module):
    # Bilinear Attention Pooling, Section 3.1.2 of the paper
    def __init__(self, **kwargs):
        super(BAP, self).__init__()

    def forward(self, feature_maps, attention_maps):
        feature_shape = feature_maps.size()      # e.g. 12*768*26*26
        attention_shape = attention_maps.size()  # e.g. 12*num_parts*26*26, where 12 is the batch size
        # https://zhuanlan.zhihu.com/p/44954540
        # For each (attention map, feature map) pair, multiply element-wise and sum over
        # the spatial dims: a batched bilinear pooling.
        phi_I = torch.einsum('imjk,injk->imn', (attention_maps, feature_maps))  # 12*num_parts*768, each entry summed over the 26*26 positions
        phi_I = torch.div(phi_I, float(attention_shape[2] * attention_shape[3]))  # average over the spatial positions
        phi_I = torch.mul(torch.sign(phi_I), torch.sqrt(torch.abs(phi_I) + 1e-12))  # signed square-root normalization
        phi_I = phi_I.view(feature_shape[0], -1)  # 12*(num_parts*768)
        raw_features = torch.nn.functional.normalize(phi_I, dim=-1)  # L2-normalize the flattened part features
        pooling_features = raw_features * 100  # scale up before the classifier
        return raw_features, pooling_features
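To make the einsum concrete: it is just a batched matrix product between the flattened attention maps and the flattened feature maps, which is the bilinear attention pooling described in Section 3.1.2. A quick sanity check with assumed shapes:

import torch

B, C, M, H, W = 2, 768, 32, 26, 26
feature_maps = torch.randn(B, C, H, W)
attention_maps = torch.rand(B, M, H, W)

phi_einsum = torch.einsum('imjk,injk->imn', (attention_maps, feature_maps))  # (B, M, C)
phi_bmm = torch.bmm(attention_maps.reshape(B, M, -1),
                    feature_maps.reshape(B, C, -1).transpose(1, 2))          # (B, M, C)
print((phi_einsum - phi_bmm).abs().max())  # ~0, up to float32 rounding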

3. Attention crop

def attention_crop(attention_maps, input_image):
    # Attention cropping: zoom the image into one attended part.
    B, N, W, H = input_image.shape  # note: dims 2 and 3 are labelled W, H here; inputs are assumed square
    input_tensor = input_image
    batch_size, num_parts, height, width = attention_maps.shape
    # The part-weight computation below corresponds to Section 3.2.1 of the paper.
    attention_maps = torch.nn.functional.interpolate(attention_maps.detach(), size=(W, H), mode='bilinear')  # upsample; detach so the crop does not affect back-propagation
    part_weights = F.avg_pool2d(attention_maps, (W, H)).reshape(batch_size, -1)  # average-pool each map over the spatial dims
    part_weights = torch.add(torch.sqrt(part_weights), 1e-12)  # sqrt rescaling, plus a small epsilon for stability
    part_weights = torch.div(part_weights, torch.sum(part_weights, dim=1).unsqueeze(1)).cpu()  # normalize into a probability distribution over parts
    part_weights = part_weights.numpy()
    ret_imgs = []
    for i in range(batch_size):
        # Corresponds to Section 3.2.2.
        attention_map = attention_maps[i]
        part_weight = part_weights[i]
        # https://www.cnblogs.com/cloud-ken/p/9931273.html
        # Sample one part index in [0, num_parts) with probability proportional to its attention energy.
        selected_index = np.random.choice(
            np.arange(0, num_parts), 1, p=part_weight)[0]
        mask = attention_map[selected_index, :, :].cpu().numpy()  # move to CPU for the numpy thresholding below
        threshold = random.uniform(0.4, 0.6)  # random threshold
        # Indices of all positions whose attention exceeds threshold * max
        # https://www.cnblogs.com/massquantity/p/8908859.html
        itemindex = np.where(mask >= mask.max() * threshold)
        # Bounding box of the attended region, padded by 10% of the image size.
        padding_h = int(0.1 * H)
        padding_w = int(0.1 * W)
        height_min = itemindex[0].min()
        height_min = max(0, height_min - padding_h)
        height_max = itemindex[0].max() + padding_h
        width_min = itemindex[1].min()
        width_min = max(0, width_min - padding_w)
        width_max = itemindex[1].max() + padding_w
        out_img = input_tensor[i][:, height_min:height_max, width_min:width_max].unsqueeze(0)  # crop to the attended region
        out_img = torch.nn.functional.interpolate(out_img, size=(W, H), mode='bilinear', align_corners=True)  # resize back to the input size
        out_img = out_img.squeeze(0)
        ret_imgs.append(out_img)
    ret_imgs = torch.stack(ret_imgs)
    return ret_imgs

4. Attention Dropping


def attention_drop(attention_maps, input_image):
    # Attention dropping: erase one attended part so the network must look for
    # other discriminative regions.
    B, N, W, H = input_image.shape
    input_tensor = input_image
    batch_size, num_parts, height, width = attention_maps.shape
    attention_maps = torch.nn.functional.interpolate(attention_maps.detach(), size=(W, H), mode='bilinear')
    # Same part-weight computation as in attention_crop.
    part_weights = F.avg_pool2d(attention_maps, (W, H)).reshape(batch_size, -1)
    part_weights = torch.add(torch.sqrt(part_weights), 1e-12)
    part_weights = torch.div(part_weights, torch.sum(part_weights, dim=1).unsqueeze(1)).cpu().numpy()
    masks = []
    for i in range(batch_size):
        attention_map = attention_maps[i].detach()
        part_weight = part_weights[i]
        selected_index = np.random.choice(
            np.arange(0, num_parts), 1, p=part_weight)[0]
        mask = attention_map[selected_index:selected_index + 1, :, :]
        # Hard mask: zero out everything above a random fraction of the maximum activation.
        # (A soft-mask variant is left commented out in the original repo.)
        threshold = random.uniform(0.2, 0.5)
        mask = (mask < threshold * mask.max()).float()
        masks.append(mask)
    masks = torch.stack(masks)  # (B, 1, H, W), broadcast over the channel dimension
    ret = input_tensor * masks
    return ret
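The training loop in Section 5 calls attention_crop_drop, which is not shown in this post. A minimal sketch, assuming it simply combines the two functions above (the repo version may fuse the two passes to avoid duplicated work):

def attention_crop_drop(attention_maps, input_image):
    # Assumed convenience wrapper: returns one attention-cropped batch and one
    # attention-dropped batch computed from the same attention maps.
    img_crop = attention_crop(attention_maps, input_image)
    img_drop = attention_drop(attention_maps, input_image)
    return img_crop, img_drop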

5. The test procedure

def mask2bbox(attention_maps, input_image):
    # The crop used at test time: same idea as attention_crop, but instead of
    # sampling one part it averages all attention maps and uses a fixed threshold.
    input_tensor = input_image
    B, C, H, W = input_tensor.shape
    batch_size, num_parts, Hh, Ww = attention_maps.shape
    attention_maps = torch.nn.functional.interpolate(attention_maps, size=(W, H), mode='bilinear')
    ret_imgs = []
    for i in range(batch_size):
        attention_map = attention_maps[i]
        mask = attention_map.mean(dim=0)  # the only difference from training: take the mean over all parts
        threshold = 0.1  # fixed threshold instead of a random one
        max_activate = mask.max()
        min_activate = threshold * max_activate
        itemindex = torch.nonzero(mask >= min_activate)
        # Bounding box of the activated region, padded by 5% of the image size.
        padding_h = int(0.05 * H)
        padding_w = int(0.05 * W)
        height_min = itemindex[:, 0].min()
        height_min = max(0, height_min - padding_h)
        height_max = itemindex[:, 0].max() + padding_h
        width_min = itemindex[:, 1].min()
        width_min = max(0, width_min - padding_w)
        width_max = itemindex[:, 1].max() + padding_w
        out_img = input_tensor[i][:, height_min:height_max, width_min:width_max].unsqueeze(0)
        out_img = torch.nn.functional.interpolate(out_img, size=(W, H), mode='bilinear', align_corners=True)
        out_img = out_img.squeeze(0)
        ret_imgs.append(out_img)
    ret_imgs = torch.stack(ret_imgs)
    return ret_imgs


# For easier side-by-side comparison, the train, validation and test code are all shown below.
class Engine():
    # Implements the details of the training, validation and test steps.
    def __init__(self,):
        pass

    def train(self,state,epoch):
        batch_time = AverageMeter() # http://codingdict.com/sources/py/utils/5219.html
        data_time = AverageMeter()  # AverageMeter keeps a running sum and count, exposing the latest value (.val) and the average (.avg)
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        config = state['config']
        print_freq = config.print_freq # print frequency
        model = state['model']
        criterion = state['criterion']
        optimizer = state['optimizer']
        train_loader = state['train_loader']
        model.train()
        end = time.time()
        for i, (img, label) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            target = label.cuda()
            input = img.cuda()
            # compute output
            attention_maps, raw_features, output1 = model(input)  # the model returns attention maps, BAP features and logits
            features = raw_features.reshape(raw_features.shape[0], -1)

            feature_center_loss, center_diff = calculate_pooling_center_loss(
                features, state['center'], target, alfa=config.alpha)  # feature-center loss: pull each BAP feature toward its class center

            # update model.centers
            state['center'][target] += center_diff

            # compute refined loss
            # img_drop = attention_drop(attention_maps,input)
            # img_crop = attention_crop(attention_maps, input)
            img_crop, img_drop = attention_crop_drop(attention_maps, input)
            _, _, output2 = model(img_drop)
            _, _, output3 = model(img_crop)

            loss1 = criterion(output1, target)
            loss2 = criterion(output2, target)
            loss3 = criterion(output3, target)

            loss = (loss1+loss2+loss3)/3 + feature_center_loss
            # measure accuracy and record loss
            prec1, prec5 = accuracy(output1, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward() # back-propagate the gradients
            optimizer.step() # update the parameters with the optimizer

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            # print every print_freq batches
            if i % print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        epoch, i, len(train_loader), batch_time=batch_time,
                        data_time=data_time, loss=losses, top1=top1, top5=top5))
                print("loss1,loss2,loss3,feature_center_loss", loss1.item(), loss2.item(), loss3.item(),
                    feature_center_loss.item())
        return top1.avg, losses.avg
    
    def validate(self,state):
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        
        config = state['config']
        print_freq = config.print_freq
        model = state['model']
        val_loader = state['val_loader']
        criterion = state['criterion']
        # switch to evaluate mode
        model.eval()
        with torch.no_grad():
            end = time.time()
            for i, (input, target) in enumerate(val_loader):
                target = target.cuda()
                input = input.cuda()
                # forward
                attention_maps, raw_features, output1 = model(input)
                features = raw_features.reshape(raw_features.shape[0], -1)
                feature_center_loss, _ = calculate_pooling_center_loss(
                    features, state['center'], target, alfa=config.alpha)

                img_crop, img_drop = attention_crop_drop(attention_maps, input)
                # img_drop = attention_drop(attention_maps,input)
                # img_crop = attention_crop(attention_maps,input)
                _, _, output2 = model(img_drop)
                _, _, output3 = model(img_crop)
                loss1 = criterion(output1, target)
                loss2 = criterion(output2, target)
                loss3 = criterion(output3, target)
                # loss = loss1 + feature_center_loss
                loss = (loss1+loss2+loss3)/3+feature_center_loss
                # measure accuracy and record loss
                prec1, prec5 = accuracy(output1, target, topk=(1, 5))
                losses.update(loss.item(), input.size(0))
                top1.update(prec1[0], input.size(0))
                top5.update(prec5[0], input.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                            i, len(val_loader), batch_time=batch_time, loss=losses,
                            top1=top1, top5=top5))

            print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
                .format(top1=top1, top5=top5))

        return top1.avg, losses.avg

    def test(self,val_loader, model, criterion):
        top1 = AverageMeter()
        top5 = AverageMeter()
        print_freq = 100
        # switch to evaluate mode
        model.eval()
        with torch.no_grad():
            for i, (input, target) in enumerate(val_loader):
                target = target.cuda()
                input = input.cuda()
                # forward
                attention_maps, _, output1 = model(input)  # coarse prediction p1 from the raw image
                refined_input = mask2bbox(attention_maps, input)  # cropped input for prediction p2; unlike the training crop, mask2bbox averages all attention maps and uses a fixed threshold
                _, _, output2 = model(refined_input)
                output = (F.softmax(output1, dim=-1)+F.softmax(output2, dim=-1))/2
                # measure accuracy and record loss
                prec1, prec5 = accuracy(output, target, topk=(1, 5))
                top1.update(prec1[0], input.size(0))
                top5.update(prec5[0], input.size(0))

                if i % print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                            i, len(val_loader),
                            top1=top1, top5=top5))

            print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
                .format(top1=top1, top5=top5))

        return top1.avg, top5.avg
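The helpers AverageMeter, accuracy and calculate_pooling_center_loss are not included in the snippets above. AverageMeter and accuracy are the usual utilities from the PyTorch ImageNet example (running average of a metric, top-k accuracy). For the center loss, the sketch below is only my reading of what it should compute, assuming state['center'] has shape (num_classes, feature_dim); the exact formula and the role of alfa in the repo may differ.

import torch
import torch.nn.functional as F

class AverageMeter:
    """Tracks the latest value and the running average of a metric."""
    def __init__(self):
        self.val = self.sum = self.count = self.avg = 0.0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def calculate_pooling_center_loss(features, centers, targets, alfa=0.95):
    """Sketch of the feature-center loss: pull each sample's BAP feature toward its
    class center, and return the moving-average update to be applied to the centers."""
    features = F.normalize(features, dim=-1)  # BAP features are L2-normalized
    centers_batch = centers[targets]          # (batch, feature_dim)
    center_loss = ((features - centers_batch) ** 2).sum(dim=-1).mean()
    # Step each sampled class center toward the new feature; the caller applies it
    # via state['center'][target] += center_diff.
    center_diff = (1.0 - alfa) * (features.detach() - centers_batch)
    return center_loss, center_diff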

 
