# train_cam: per-layer output shapes of the network

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# import resnet50
from net import resnet50
import sys
# sys.path.append('/gly/ab/WSSS/WSSS/ReCAM+CCAM/')
from misc import torchutils
import misc.torchutils

from torchvision import transforms
from net import loss

class Net(nn.Module):
    """ResNet-50 backbone with a 1x1-conv classifier head for multi-label
    classification (CAM training).

    Args:
        stride: backbone output stride; 16 uses a strided layer4, any other
            value keeps resolution via dilated convolutions in layer3/4.
        n_classes: number of foreground classes (20 — presumably PASCAL VOC).
    """

    def __init__(self, stride=16, n_classes=20):
        super(Net, self).__init__()
        # Only the stride/dilation configuration differs between the two
        # settings, so build the backbone per-branch but the stages once
        # (the original duplicated the stage1 construction in both branches).
        if stride == 16:
            self.resnet50 = resnet50.resnet50(pretrained=True, strides=(2, 2, 2, 1))
        else:
            self.resnet50 = resnet50.resnet50(pretrained=True, strides=(2, 2, 1, 1), dilations=(1, 1, 2, 2))
        self.stage1 = nn.Sequential(self.resnet50.conv1, self.resnet50.bn1,
                                    self.resnet50.relu, self.resnet50.maxpool,
                                    self.resnet50.layer1)
        self.stage2 = nn.Sequential(self.resnet50.layer2)
        self.stage3 = nn.Sequential(self.resnet50.layer3)
        self.stage4 = nn.Sequential(self.resnet50.layer4)
        self.n_classes = n_classes
        # 1x1 conv acting as a linear classifier over globally pooled features.
        self.classifier = nn.Conv2d(2048, n_classes, 1, bias=False)

        # Grouped for trainable_parameters(): backbone vs. newly added head.
        self.backbone = nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4])
        self.newly_added = nn.ModuleList([self.classifier])

    def forward(self, x):
        """Return per-class logits of shape (batch, n_classes)."""
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        x = torchutils.gap2d(x, keepdims=True)  # global average pool -> (B, 2048, 1, 1)
        x = self.classifier(x)
        x = x.view(-1, self.n_classes)

        return x

    def train(self, mode=True):
        """Switch train/eval mode while keeping the stem (conv1/bn1) frozen."""
        super(Net, self).train(mode)
        for p in self.resnet50.conv1.parameters():
            p.requires_grad = False
        for p in self.resnet50.bn1.parameters():
            p.requires_grad = False

    def trainable_parameters(self):
        """Return (backbone params, new-head params) for per-group learning rates."""
        return (list(self.backbone.parameters()), list(self.newly_added.parameters()))

class Net_CAM(Net):
    """Net variant whose forward also returns raw CAMs and the stage-4 features."""

    def __init__(self, stride=16, n_classes=20):
        super(Net_CAM, self).__init__(stride=stride, n_classes=n_classes)

    def forward(self, x):
        """Return (logits, relu'd class activation maps, stage-4 feature map)."""
        feature = self.stage4(self.stage3(self.stage2(self.stage1(x))))

        pooled = torchutils.gap2d(feature, keepdims=True)
        logits = self.classifier(pooled).view(-1, self.n_classes)

        # CAMs: apply the classifier weights spatially over the feature map.
        cams = F.relu(F.conv2d(feature, self.classifier.weight))

        return logits, cams, feature

class Net_CAM_Feature(Net):
    """Net variant returning logits, per-class CAM-weighted feature vectors,
    and max-normalized CAMs (used by ReCAM-style training)."""

    def __init__(self, stride=16, n_classes=20):
        super(Net_CAM_Feature, self).__init__(stride=stride, n_classes=n_classes)

    def forward(self, x):
        """Return (logits, cams_feature of shape (B, n_classes, 2048), cams)."""
        feature = self.stage4(self.stage3(self.stage2(self.stage1(x))))  # (B, 2048, H, W)

        pooled = torchutils.gap2d(feature, keepdims=True)  # (B, 2048, 1, 1)
        logits = self.classifier(pooled).view(-1, self.n_classes)  # (B, n_classes)

        cams = F.relu(F.conv2d(feature, self.classifier.weight))
        # Normalize each map by its spatial maximum; epsilon guards all-zero maps.
        cams = cams / (F.adaptive_max_pool2d(cams, (1, 1)) + 1e-5)

        # Per-class feature vector: spatial mean of the feature map weighted
        # by each class's CAM -> (B, n_classes, 2048).
        weighted = cams.unsqueeze(2) * feature.unsqueeze(1)  # (B, n_classes, 2048, H, W)
        cams_feature = weighted.reshape(weighted.size(0), weighted.size(1),
                                        weighted.size(2), -1).mean(-1)

        return logits, cams_feature, cams


class CAM(Net):
    """Inference-time CAM extractor.

    All three forward variants run the same pipeline — backbone, 1x1 weight
    projection, then (unless ``separate``) relu and fusion of the original
    and horizontally-flipped views — so they share one private helper.
    They differ only in which projection weight is used.
    """

    def __init__(self, stride=16, n_classes=20):
        super(CAM, self).__init__(stride=stride, n_classes=n_classes)

    def _project(self, x, weight, separate):
        """Backbone + 1x1 projection with ``weight``; optionally fuse flipped pair."""
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = F.conv2d(x, weight)
        if separate:
            return x
        x = F.relu(x)
        # Batch is assumed to hold (image, hflipped image); un-flip and sum.
        x = x[0] + x[1].flip(-1)
        return x

    def forward(self, x, separate=False):
        """CAMs from the trained classifier weights."""
        return self._project(x, self.classifier.weight, separate)

    def forward1(self, x, weight, separate=False):
        """CAMs from externally supplied projection weights."""
        return self._project(x, weight, separate)

    def forward2(self, x, weight, separate=False):
        """CAMs from supplied weights modulated (elementwise) by the classifier weights."""
        return self._project(x, weight * self.classifier.weight, separate)


class Class_Predictor(nn.Module):
    """Per-class re-classification head (ReCAM) combined with CCAM contrastive losses.

    Args:
        num_classes: number of foreground classes.
        representation_size: channel dimension of the per-class features (2048 here).
    """

    def __init__(self, num_classes, representation_size):
        super(Class_Predictor, self).__init__()
        self.num_classes = num_classes
        self.classifier = nn.Conv2d(representation_size, num_classes, 1, bias=False)
        self.ac_head = Disentangler(self.num_classes)
        # `loss` is the net.loss module; SimMax/SimMin losses are built from it.
        self.loss = loss

    def forward(self, x, cams_feature, label, inference=False):
        """Compute the combined (cross-entropy + CCAM) loss and accuracy.

        Args:
            x: flattened CAMs, reshaped below to (B, num_classes, 32, 32).
            cams_feature: per-class feature vectors (B, num_classes, repr_size).
            label: multi-hot class labels (B, num_classes).
            inference: forwarded to the Disentangler head.

        Returns:
            (loss averaged over the batch, classification accuracy).
        """
        batch_size = x.shape[0]
        # NOTE(review): hard-codes a 32x32 CAM resolution (stride-16 on a
        # 512x512 input) — confirm against the data pipeline.
        x = x.reshape(batch_size, self.num_classes, 32, 32)
        mask = label > 0  # (B, num_classes) bool: classes present in each image

        # Per image: the CAM rows of the n classes present in its label.
        contrast_feature_list = [x[i][mask[i]] for i in range(batch_size)]
        criterion = [self.loss.SimMaxLoss(metric='cos', alpha=0.25).cuda(),
                     self.loss.SimMinLoss(metric='cos').cuda(),
                     self.loss.SimMaxLoss(metric='cos', alpha=0.25).cuda()]
        loss_CCAM = 0
        for feature in contrast_feature_list:
            n = feature.shape[0]
            if n == 0:
                # The label lists a class but the CAM found nothing for it
                # (the original noted this empty-tensor case); skip the image.
                continue
            # NOTE(review): a fresh, untrained Disentangler per image cannot
            # learn across steps and is never moved to CUDA — consider
            # reusing self.ac_head instead; kept for behavioral parity.
            ac_head = Disentangler(n)
            fg_feats, bg_feats, ccam = ac_head(feature.unsqueeze(0), inference=inference)
            loss_CCAM = (loss_CCAM + criterion[0](fg_feats)
                         + criterion[1](bg_feats, fg_feats)
                         + criterion[2](bg_feats))

        # Re-classify each present class from its CAM-weighted feature vector.
        refeature_list = [cams_feature[i][mask[i]] for i in range(batch_size)]
        predictions = [self.classifier(y.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1)
                       for y in refeature_list]  # B x (n_i, num_classes)
        targets = [torch.nonzero(label[i]).squeeze(1) for i in range(label.shape[0])]

        loss_ce = 0
        acc = 0
        num = 0
        # `target` deliberately does not shadow the `label` parameter
        # (the original's loop variable did).
        for logit, target in zip(predictions, targets):
            if target.shape[0] == 0:
                continue
            loss_ce = loss_ce + F.cross_entropy(logit, target)
            acc += (logit.argmax(dim=1) == target.view(-1)).sum().float()
            num += target.size(0)

        # Bug fix: the original computed the final loss (and printed it) inside
        # the loop, leaving `l` unbound — a NameError — whenever every image
        # had an empty label set; it also divided acc by num without a guard.
        total_loss = (loss_ce + loss_CCAM) / batch_size
        accuracy = acc / num if num > 0 else torch.tensor(0.0)
        return total_loss, accuracy

# x = torch.randn(16, 20, 32, 32)
# import random
#
# label = [[random.randint(0, 1) for j in range(20)] for i in range(16)]
# label = torch.tensor(label)
# # print(label)
# cams_feature = torch.randn(16, 20, 2048)
# recam_predictor = Class_Predictor(20, 2048).cuda()
# recam_predictor = torch.nn.DataParallel(recam_predictor)
# recam_predictor.train()
# loss_ce,acc = recam_predictor(x, cams_feature, label)
#
# print(loss_ce,acc)