在网络训练过程中实时显示 CAM 结果
Paper: *Learning Deep Features for Discriminative Localization* (Zhou et al., CVPR 2016)
主要代码
def ComputeCAM(features, softmax_weight, class_ids, scale_factor=8, *,
               size=None, per_sample=True, eps=1e-12):
    """Compute class activation maps (CAM), normalize them to [0, 1] and upsample.

    For every sample b the map is ``sum_c softmax_weight[class_ids[b], c] *
    features[b, c]``, i.e. the feature channels weighted by the classifier
    row of the chosen class.

    Args:
        features: 4-D tensor ``(B, C, H, W)`` of last-conv feature maps.
        softmax_weight: 2-D tensor ``(num_classes, C)`` of classifier weights.
        class_ids: 1-D tensor ``(B,)`` with the class to visualise per sample;
            converted to ``long`` before indexing.
        scale_factor: Upsampling factor used when ``size`` is not given.
            Defaults to 8, which assumes the feature map is 1/8 of the input
            resolution — TODO confirm for the backbone in use.
        size: Optional explicit ``(H_out, W_out)``; takes precedence over
            ``scale_factor`` (e.g. pass ``inputs.shape[-2:]``).
        per_sample: If True (default), normalize each sample's map
            independently, so a heatmap does not depend on the rest of the
            batch. If False, use one min/max over the whole batch (the
            previous behaviour).
        eps: Lower bound on the normalization denominator; a constant map
            becomes all zeros instead of NaN.

    Returns:
        Tensor ``(B, 1, H_out, W_out)`` with values in [0, 1].

    Raises:
        IndexError: if a class id is out of range for ``softmax_weight``.
    """
    # One gather instead of a Python loop + torch.cat: (B, C).
    weights = softmax_weight[class_ids.long()]
    # Broadcast (B, C, 1, 1) over the spatial dims, then reduce channels.
    cam = (weights[:, :, None, None] * features).sum(dim=1, keepdim=True)

    if per_sample:
        flat = cam.flatten(start_dim=1)
        lo = flat.min(dim=1).values.view(-1, 1, 1, 1)
        hi = flat.max(dim=1).values.view(-1, 1, 1, 1)
    else:
        lo, hi = cam.min(), cam.max()
    # clamp_min guards the 0/0 case when a map is constant.
    cam = (cam - lo) / (hi - lo).clamp_min(eps)

    # Bilinear interpolation is a convex combination, so values stay in [0, 1].
    if size is not None:
        return torch.nn.functional.interpolate(
            cam, size=size, mode='bilinear', align_corners=False)
    return torch.nn.functional.interpolate(
        cam, scale_factor=scale_factor, mode='bilinear', align_corners=False)
def GetCam(model, inputs, softmax_weight=None):
    """Run ``model`` on ``inputs`` and return CAMs for the predicted classes.

    The model is expected to return a ``(logits, features)`` pair. The class
    with the highest logit per sample is visualised.

    The model is run in whatever train/eval mode it is currently in.

    Args:
        model: Network whose forward returns ``(logits, features)``.
        inputs: Input batch passed straight to ``model``.
        softmax_weight: Optional classifier weight ``(num_classes, C)``. If
            omitted, the second-to-last registered parameter of ``model`` is
            used (squeezed), as before.

    Returns:
        The CAM tensor produced by ``ComputeCAM`` (no autograd history).
    """
    # Visualisation only: skip autograd so displaying CAMs during training
    # does not build a second graph that the returned tensor keeps alive.
    with torch.no_grad():
        out, features = model(inputs)
        _, idx = torch.max(out, 1)
        if softmax_weight is None:
            # NOTE(review): assumes the second-to-last parameter is the final
            # classifier's weight (e.g. an nn.Linear with bias) — verify for
            # the model in use, or pass softmax_weight explicitly.
            softmax_weight = list(model.parameters())[-2].squeeze()
        return ComputeCAM(features, softmax_weight, idx)
if __name__ == "__main__":
    # Smoke test on a random batch of ten 3x32x32 images.
    # NOTE(review): `net` and `device` are not defined in this snippet;
    # they presumably come from the full training script.
    dummy_batch = torch.rand(10, 3, 32, 32).to(device)
    cam_maps = GetCam(net, dummy_batch)
    print(cam_maps.size())
完整代码链接: https://github.com/Jixianrui/CAM-with-ResNet
结果展示
由于时间原因，以下结果仅为训练 10 个 epoch 后的结果。