PyTorch Swin-Transformer 各层特征可视化

PyTorch相关开源库
https://gitee.com/hejuncheng1/pytorch-grad-cam

安装命令

pip install grad-cam

具体使用参考
Swin Transformer各层特征可视化_不高兴与没头脑Fire的博客-CSDN博客

提供示例

# dataloader.py
from torchvision import datasets, transforms
import os
import torch

input_size = 224

# ImageNet per-channel statistics, shared by every split's Normalize step.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _build_data_transforms(size):
    """Build the per-split preprocessing pipelines for a square ``size`` input.

    Single source of truth for the transforms, used both at import time and by
    ``update`` so the two can never drift apart.

    Args:
        size: Target side length in pixels; images are resized to ``size x size``.

    Returns:
        Dict mapping ``'train'``, ``'val'`` and ``'test'`` to a
        ``transforms.Compose``. Only ``'train'`` applies random augmentation;
        ``'val'`` and ``'test'`` are deterministic resize + normalize.
    """
    square = (size, size)

    def eval_pipeline():
        # Fresh Compose per call so 'val' and 'test' remain distinct objects.
        return transforms.Compose([
            transforms.Resize(square),
            transforms.ToTensor(),
            transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ])

    return {
        'train': transforms.Compose([
            transforms.Resize(square),
            # Crop 70-100% of the area, then resize back to `size`.
            transforms.RandomResizedCrop(size=size, scale=(0.7, 1)),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ]),
        'val': eval_pipeline(),
        'test': eval_pipeline(),
    }


data_transforms = _build_data_transforms(input_size)


def update(new_input_size):
    """Rebuild the module-level transforms for a new input resolution.

    Rebinds the module globals ``input_size`` and ``data_transforms``. Read
    them through the module (e.g. ``dataloader.data_transforms``) after calling
    this. A name bound earlier via ``from dataloader import data_transforms``
    still refers to the old dict.

    Args:
        new_input_size: New square side length in pixels.
    """
    global input_size
    global data_transforms

    input_size = new_input_size
    data_transforms = _build_data_transforms(input_size)


def dataloader(data_dir, batch_size, set_name, shuffle):
    """Create a DataLoader for the split stored under ``<data_dir>/<set_name>``.

    The split is read with ``datasets.ImageFolder`` and preprocessed with the
    module-level ``data_transforms[set_name]``. That global is looked up at call
    time, so a preceding ``update()`` is honoured.

    Returns:
        ``(loaders, size)``: ``loaders`` is a one-entry dict
        ``{set_name: DataLoader}``; ``size`` is the number of images in the split.
    """
    split_dir = os.path.join(data_dir, set_name)
    split_dataset = datasets.ImageFolder(split_dir, data_transforms[set_name])

    # One background worker when CUDA is present, in-process loading otherwise.
    workers = 0 if not torch.cuda.is_available() else 1
    split_loader = torch.utils.data.DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=workers,
    )
    return {set_name: split_loader}, len(split_dataset)


if __name__ == '__main__':
    # Smoke test: set `root` to a directory that contains a 'train' sub-folder.
    root = ''
    loaders, n_images = dataloader(data_dir=root, batch_size=16, set_name='train', shuffle=True)
    print(loaders, n_images)
# main.py
import cv2
import numpy as np
import torch
import torch.nn as nn
import os
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image

import dataloader


def reshape_transform(tensor, height=12, width=12):
    """Turn a Swin token sequence into a channels-first feature map for Grad-CAM.

    Expects ``tensor`` shaped ``(batch, height * width, channels)``. It is
    reshaped to ``(batch, height, width, channels)`` and returned as
    ``(batch, channels, height, width)``, the layout Grad-CAM expects.

    The 12x12 default presumably matches the final stage of
    swin_base_patch4_window12_384 (384 / 4 / 2**3 = 12). Confirm for other
    models or input sizes.
    """
    batch, channels = tensor.size(0), tensor.size(2)
    grid = tensor.reshape(batch, height, width, channels)
    # NHWC -> NCHW in one step; equivalent to transpose(2, 3).transpose(1, 2).
    return grid.permute(0, 3, 1, 2)


if __name__ == '__main__':
    # Grad-CAM visualisation of a fine-tuned Swin model on a single image.
    net_name = 'swin_base_patch4_window12_384_22k'
    categories_size = 2
    model_ft = None

    if net_name == 'swin_base_patch4_window12_384_22k':
        from models import swintf

        model_ft = swintf.build_model('config/swin_base_patch4_window12_384_22k.yaml', use_checkpoint=True)
        # Replace the pretrained head with a `categories_size`-way classifier.
        # 1024 is assumed to be swin_base's final feature width; confirm if the config changes.
        model_ft.head = nn.Linear(1024, categories_size)
        dataloader.update(384)

    if model_ft is None:
        # Fail clearly instead of an AttributeError on None further down.
        raise ValueError(f'unsupported net_name: {net_name!r}')

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        model_ft = model_ft.cuda()

    load_path = os.path.join('./save', net_name + '.pth')
    if os.path.exists(load_path):
        device = torch.device("cuda" if use_gpu else "cpu")
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        msg = model_ft.load_state_dict(torch.load(load_path, map_location=device))
        # Printed only here: `msg` does not exist when no checkpoint was loaded.
        print('msg:', msg)
    else:
        print(f'warning: no checkpoint at {load_path}; the classification head is untrained')

    model_ft.eval()

    target_layer = [model_ft.norm]

    target_category = 0
    image_path = ''  # placeholder: set to the image to visualise before running
    # Force 3 channels so Normalize's 3-value mean/std and the overlay both work
    # for RGBA / palette / grayscale inputs.
    image = Image.open(image_path).convert('RGB')
    transformer = dataloader.data_transforms['test']
    inputs = transformer(image).unsqueeze(0)

    # use_cuda must follow the model's device: the library moves the input to
    # CUDA only when use_cuda=True, so False with a CUDA model is a device mismatch.
    # NOTE(review): `use_cuda` / `target_category` belong to older
    # pytorch-grad-cam releases; newer ones take `targets=[...]` and infer the
    # device. Confirm against the installed version.
    cam = GradCAM(model=model_ft, target_layers=target_layer, use_cuda=use_gpu, reshape_transform=reshape_transform)
    cam.batch_size = 1
    grayscale_cam = cam(input_tensor=inputs, target_category=target_category, eigen_smooth=True,
                        aug_smooth=True)
    grayscale_cam = grayscale_cam[0, :]

    # Resize to the same resolution the model saw, so it matches the CAM size.
    side = dataloader.input_size
    rgb_image = np.array(image.resize((side, side))) / 255.0
    # The PIL image is RGB: blend an RGB heatmap, then convert to BGR for OpenCV.
    cam_image = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)
    cv2.imwrite('cam.jpg', cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR))
    print('OK')
  • 4
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
引用[1]提供了PyTorch和TensorFlow2中实现Swin-Transformer的代码，引用[2]给出了Swin-Transformer图像分割的GitHub地址。如果你想修改PyTorch中Swin-Transformer的通道数，可以按照以下步骤操作：1. 确保已经安装了PyTorch及相关依赖库。2. 下载Swin-Transformer的PyTorch实现代码，可以在GitHub上找到相关的代码仓库。3. 打开Swin-Transformer的代码文件，找到与通道数相关的部分，通常位于模型定义或配置中。在Swin-Transformer中，各阶段的通道数通常由配置中的 embed_dim 决定，每经过一次 patch merging 翻倍。例如 swin_base 最后一层的通道数为 1024，这正是上面代码中 nn.Linear(1024, categories_size) 的来源。4. 根据需求修改相应的通道数。可以增加或减少通道数，但要确保修改后的通道数与模型其他部分（如分类头的输入维度）保持一致。5. 保存修改后的代码文件，并重新运行程序。需要注意的是，修改通道数可能影响模型的性能和效果，并且预训练权重可能无法再直接加载，因此建议先进行实验和测试，确保修改后的模型仍然具有良好的性能。希望这个回答对你有帮助！参考文献：[1] Swin-Transformer网络结构详解（https://blog.csdn.net/qq_37541097/article/details/121119988）；[2][3] Swin-Transformer 图像分割实战：使用Swin-Transformer-Semantic-Segmentation训练ADE20K数据集(语义分割...)（https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/121904901）

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值