pytorch可视化CNN每层的特征

在PyTorch中,可以使用torchvision.utils.make_grid来将特征图可视化为一个网格。具体步骤如下:
1.定义一个数据集并加载数据

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 定义数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, 
                                 transform=transforms.ToTensor(), download=True)
# 加载数据
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

2.加载预训练模型并定义一个函数,用于提取每一层的特征

import torch.nn as nn
import torchvision.models as models

# 加载预训练模型
model = models.resnet18(pretrained=True)

# 定义一个函数,用于提取每一层的特征
def get_features(x, model, layers):
    features = []
    for name, module in model._modules.items():
        x = module(x)
        if name in layers:
            features.append(x)
    return features

3.获取特定层的名称并将数据集中的一批数据输入到模型中,获取对应层的特征

# 获取模型中所有层的名称
all_layers = []
for name, layer in model.named_modules():
    all_layers.append(name)

# 获取需要可视化的层的名称
layers = all_layers[4:9]

# 获取一批数据
data, _ = next(iter(train_loader))

# 将数据输入到模型中,并获取对应层的特征
features = get_features(data, model, layers)

4.将特征可视化为网格

import matplotlib.pyplot as plt
import numpy as np

# 定义一个函数,用于将特征可视化为网格
def visualize_features(features):
    nrow = len(features)
    ncols = features[0].shape[1]
    fig, axs = plt.subplots(nrow, ncols, figsize=(10, 10))

    for i in range(nrow):
        for j in range(ncols):
            img = features[i][0][j].detach().numpy()
            img = np.transpose(img, (1, 2, 0))
            img = (img - img.min()) / (img.max() - img.min())
            axs[i][j].imshow(img)
            axs[i][j].axis('off')
            if j == 0:
                axs[i][j].set_title(layers[i])
    plt.show()

# 可视化特征
visualize_features(features)

以上代码会将提取的特定层的特征可视化为一个网格,并在网格的左侧显示对应层的名称。

  • 0
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
好的,下面是实现的步骤: 1. 首先需要定义自己的网络模型,并加载预训练好的权重。 2. 导入需要用到的库,包括pytorch、numpy、matplotlib等。 3. 定义一个函数来获取任意层的特征。该函数需要输入一个图片和目标层的名称,然后输出该层的特征。 4. 定义一个函数来进行tsne可视化。该函数需要输入一个特征矩阵和标签,然后输出可视化结果。 下面是代码示例: ```python import torch import numpy as np import matplotlib.pyplot as plt from sklearn.manifold import TSNE # 定义自己的网络模型 class MyNet(torch.nn.Module): def __init__(self): super(MyNet, self).__init__() self.conv1 = torch.nn.Conv2d(3, 64, 3, padding=1) self.conv2 = torch.nn.Conv2d(64, 128, 3, padding=1) self.conv3 = torch.nn.Conv2d(128, 256, 3, padding=1) self.fc1 = torch.nn.Linear(256*8*8, 1024) self.fc2 = torch.nn.Linear(1024, 10) def forward(self, x): x = torch.nn.functional.relu(self.conv1(x)) x = torch.nn.functional.relu(self.conv2(x)) x = torch.nn.functional.relu(self.conv3(x)) x = x.view(-1, 256*8*8) x = torch.nn.functional.relu(self.fc1(x)) x = self.fc2(x) return x # 加载预训练好的权重 model = MyNet() model.load_state_dict(torch.load('model.pth')) # 获取任意层的特征 def get_feature(image, layer_name): model.eval() features = [] def hook(module, input, output): features.append(output.detach().numpy()) layer = model._modules.get(layer_name) handle = layer.register_forward_hook(hook) model(image) handle.remove() feature = np.concatenate(features, axis=0) return feature # 进行tsne可视化 def tsne_visualization(feature, label): tsne = TSNE(n_components=2, init='pca', random_state=0) feature_tsne = tsne.fit_transform(feature) plt.scatter(feature_tsne[:,0], feature_tsne[:,1], c=label) plt.show() # 加载图片并进行tsne可视化 image = torch.randn(1, 3, 32, 32) feature = get_feature(image, 'conv2') label = [1] tsne_visualization(feature, label) ``` 你可以根据自己的需要,修改自己的网络模型和目标层名称,以及输入的图片和标签。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

酒与花生米

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值