import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# import resnet50
from net import resnet50
import sys
# sys.path.append('/gly/ab/WSSS/WSSS/ReCAM+CCAM/')
from misc import torchutils
import misc.torchutils
from torchvision import transforms
from net import loss
class Net(nn.Module):
    """ResNet-50 backbone with a bias-free 1x1-conv classifier head for
    multi-label image classification (default 20 classes, e.g. PASCAL VOC).

    Args:
        stride: Backbone output stride. 16 keeps layer4 strided; any other
            value switches layers 3-4 to dilated convolutions instead.
        n_classes: Number of foreground classes.
    """

    def __init__(self, stride=16, n_classes=20):
        super(Net, self).__init__()
        if stride == 16:
            self.resnet50 = resnet50.resnet50(pretrained=True, strides=(2, 2, 2, 1))
        else:
            self.resnet50 = resnet50.resnet50(pretrained=True, strides=(2, 2, 1, 1), dilations=(1, 1, 2, 2))
        # The stem + layer1 are identical in both configurations, so build
        # stage1 once instead of duplicating it in each branch.
        self.stage1 = nn.Sequential(self.resnet50.conv1, self.resnet50.bn1,
                                    self.resnet50.relu, self.resnet50.maxpool,
                                    self.resnet50.layer1)
        self.stage2 = nn.Sequential(self.resnet50.layer2)
        self.stage3 = nn.Sequential(self.resnet50.layer3)
        self.stage4 = nn.Sequential(self.resnet50.layer4)
        self.n_classes = n_classes
        # 1x1 conv acting as the classification head over pooled features.
        self.classifier = nn.Conv2d(2048, n_classes, 1, bias=False)
        self.backbone = nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4])
        self.newly_added = nn.ModuleList([self.classifier])

    def forward(self, x):
        """Return class logits of shape (batch, n_classes)."""
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = torchutils.gap2d(x, keepdims=True)  # global average pooling
        x = self.classifier(x)
        x = x.view(-1, self.n_classes)
        return x

    def train(self, mode=True):
        """Switch train/eval mode while keeping the stem (conv1/bn1) frozen."""
        super(Net, self).train(mode)
        for p in self.resnet50.conv1.parameters():
            p.requires_grad = False
        for p in self.resnet50.bn1.parameters():
            p.requires_grad = False

    def trainable_parameters(self):
        """Return (backbone params, new-head params) — callers typically give
        the newly added classifier a larger learning rate."""
        return (list(self.backbone.parameters()), list(self.newly_added.parameters()))
class Net_CAM(Net):
    """Net variant that also returns the raw CAMs and the layer4 features."""

    def __init__(self, stride=16, n_classes=20):
        super(Net_CAM, self).__init__(stride=stride, n_classes=n_classes)

    def forward(self, x):
        """Return (logits, relu'd CAMs, layer4 feature map)."""
        feat = self.stage4(self.stage3(self.stage2(self.stage1(x))))
        pooled = torchutils.gap2d(feat, keepdims=True)
        logits = self.classifier(pooled).view(-1, self.n_classes)
        # Applying the classifier weights as a 1x1 conv over the spatial
        # feature map yields one activation map per class.
        cams = F.relu(F.conv2d(feat, self.classifier.weight))
        return logits, cams, feat
class Net_CAM_Feature(Net):
    """Net variant returning per-class aggregated features alongside CAMs."""

    def __init__(self, stride=16, n_classes=20):
        super(Net_CAM_Feature, self).__init__(stride=stride, n_classes=n_classes)

    def forward(self, x):
        """Return (logits, cams_feature, cams).

        logits: (bs, n_classes) classification scores.
        cams_feature: (bs, n_classes, C) — CAM-weighted spatial mean of the
            layer4 feature map (C = 2048).
        cams: (bs, n_classes, H, W) — ReLU'd CAMs scaled by their spatial max.
        """
        out = self.stage1(x)
        out = self.stage2(out)
        out = self.stage3(out)
        feat = self.stage4(out)  # (bs, 2048, H, W)

        pooled = torchutils.gap2d(feat, keepdims=True)
        logits = self.classifier(pooled).view(-1, self.n_classes)

        cams = F.relu(F.conv2d(feat, self.classifier.weight))
        # Normalize each map by its spatial max (epsilon guards against /0).
        cams = cams / (F.adaptive_max_pool2d(cams, (1, 1)) + 1e-5)

        # Weight the feature map by each class's CAM, then average spatially:
        # (bs, n_cls, 1, H, W) * (bs, 1, C, H, W) -> mean over H*W -> (bs, n_cls, C).
        weighted = cams.unsqueeze(2) * feat.unsqueeze(1)
        cams_feature = weighted.view(weighted.size(0), weighted.size(1), weighted.size(2), -1).mean(-1)
        return logits, cams_feature, cams
class CAM(Net):
    """Inference-time CAM extractor.

    All forward variants assume the batch stacks an image and its horizontal
    flip along dim 0 (x[0] = original view, x[1] = flipped view): unless
    `separate=True`, the flipped maps are flipped back and summed with the
    original ones after a ReLU.
    """

    def __init__(self, stride=16, n_classes=20):
        super(CAM, self).__init__(stride=stride, n_classes=n_classes)

    def _features(self, x):
        # Shared backbone pass used by every forward variant.
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        return self.stage4(x)

    def _merge(self, x, separate):
        # `separate=True` returns the raw (pre-ReLU) maps unchanged;
        # otherwise fuse the original and mirrored views.
        if separate:
            return x
        x = F.relu(x)
        return x[0] + x[1].flip(-1)

    def forward(self, x, separate=False):
        """CAMs from the trained classifier weights."""
        return self._merge(F.conv2d(self._features(x), self.classifier.weight), separate)

    def forward1(self, x, weight, separate=False):
        """CAMs from externally supplied class weights."""
        return self._merge(F.conv2d(self._features(x), weight), separate)

    def forward2(self, x, weight, separate=False):
        """CAMs from the element-wise product of external and trained weights."""
        return self._merge(F.conv2d(self._features(x), weight * self.classifier.weight), separate)
class Class_Predictor(nn.Module):
    """ReCAM-style re-classification head with an added CCAM contrastive loss.

    For each image, the CAM channels of the classes present in the label are
    re-classified with a 1x1 conv, and a Disentangler-based foreground /
    background contrastive loss is accumulated on the same channels.

    NOTE(review): `Disentangler` and the `loss` module come from outside this
    file; their exact semantics are assumed here, not verified.
    """
    def __init__(self, num_classes, representation_size):
        # num_classes: number of foreground classes (e.g. 20 for VOC).
        # representation_size: channel dim of the per-class features (e.g. 2048).
        super(Class_Predictor, self).__init__()
        self.num_classes = num_classes
        self.classifier = nn.Conv2d(representation_size, num_classes, 1, bias=False)
        # NOTE(review): this Disentangler appears unused — forward() below
        # builds a fresh one per image instead.
        self.ac_head = Disentangler(self.num_classes)
        # self.from_scratch_layers = [self.ac_head]
        self.loss = loss  # module providing SimMaxLoss / SimMinLoss
    def forward(self, x, cams_feature, label, inference=False):
        """Return (total loss, re-classification accuracy).

        Args (assumed shapes — TODO confirm against caller):
            x: reshaped to (bs, num_classes, 32, 32) CAM maps.
            cams_feature: (bs, num_classes, representation_size).
            label: (bs, num_classes) multi-hot ground-truth labels.
        """
        batch_size = x.shape[0]
        x = x.reshape(batch_size, self.num_classes, 32, 32)  # bs*20*32*32 (torch.FloatTensor)
        mask = label > 0  # bs*20 boolean mask of present classes
        # Per image: keep only the CAM channels of classes with label > 0,
        # so each list entry is (n_i, 32, 32) with n_i = #positive classes.
        contrast_feature_list = [x[i][mask[i]] for i in range(batch_size)]
        # uploader = transforms.ToPILImage()
        # image = feature_list[0].cpu().clone()
        # image = uploader(image)
        # image.save("example.jpg")
        # CCAM criteria: maximize fg self-similarity, minimize fg/bg
        # similarity, maximize bg self-similarity.
        criterion = [self.loss.SimMaxLoss(metric='cos', alpha=0.25).cuda(), self.loss.SimMinLoss(metric='cos').cuda(),
                     self.loss.SimMaxLoss(metric='cos', alpha=0.25).cuda()]
        loss_CCAM = 0
        for feature in contrast_feature_list:
            # Reason an empty tensor can appear: the label contains the class,
            # but the CAM did not detect it.
            # if feature.numel():
            #     print('.')  # non-empty
            # else:
            #     print('!')  # empty
            # np.set_printoptions(threshold=sys.maxsize)
            # print(feature)
            n = feature.shape[0]
            # NOTE(review): a fresh (untrained, CPU) Disentangler is built for
            # every image on every forward pass — confirm this is intended.
            ac_head = Disentangler(n)
            fg_feats, bg_feats, ccam = ac_head(feature.unsqueeze(0), inference=inference)  # works on a single sample; ineffective inside the loop (original author note)
            loss1 = criterion[0](fg_feats)
            loss2 = criterion[1](bg_feats, fg_feats)
            loss3 = criterion[2](bg_feats)
            loss_CCAM += loss1
            loss_CCAM += loss2
            loss_CCAM += loss3
        # Re-classification: select per-class features of present classes and
        # push them through the 1x1-conv classifier head.
        refeature_list = [cams_feature[i][mask[i]] for i in range(batch_size)]
        prediction = [self.classifier(y.unsqueeze(-1).unsqueeze(-1)).squeeze(-1).squeeze(-1) for y in refeature_list]
        # prediction: list of (n_i, num_classes) logits, one entry per image
        labels = [torch.nonzero(label[i]).squeeze(1) for i in range(label.shape[0])]
        # labels: list of (n_i,) class-index tensors
        loss = 0
        acc = 0
        num = 0
        # NOTE: `label` is shadowed by the loop variable below.
        for logit, label in zip(prediction, labels):
            if label.shape[0] == 0:
                continue
            loss_ce = F.cross_entropy(logit, label)
            loss += loss_ce
            acc += (logit.argmax(dim=1) == label.view(-1)).sum().float()
            num += label.size(0)
        l = (loss + loss_CCAM) / batch_size
        print(l.item())  # debug output — consider removing
        print(".")
        # NOTE(review): acc/num raises ZeroDivisionError if no image has a
        # positive label — verify upstream guarantees num > 0.
        return l, acc / num
# x = torch.randn(16, 20, 32, 32)
# import random
#
# label = [[random.randint(0, 1) for j in range(20)] for i in range(16)]
# label = torch.tensor(label)
# # print(label)
# cams_feature = torch.randn(16, 20, 2048)
# recam_predictor = Class_Predictor(20, 2048).cuda()
# recam_predictor = torch.nn.DataParallel(recam_predictor)
# recam_predictor.train()
# loss_ce,acc = recam_predictor(x, cams_feature, label)
#
# print(loss_ce,acc)
# train_cam: output shape of each network layer (scraped page title — not code)
# "Latest recommended article published 2024-08-06 23:51:38" (scraped site metadata — not code)