conv1_weights = list(model_saved.parameters())[0]
conv1_images = make_grid(conv1_weights, normalize=True).cpu()
plt.figure(figsize=(8, 8))
plt.imshow(conv1_images.permute(1, 2, 0).numpy())
输出:可视化了第一个卷积层的卷积核
可以一步一步的来看:
第一步:
conv1_weights = list(model_saved.parameters())[0]
print(conv1_weights)
tensor([[[[ 1.1864e-01, 9.4069e-02, 9.5435e-02, ..., 5.5822e-02,
2.1575e-02, 4.9963e-02],
[ 7.4882e-02, 3.8940e-02, 5.2979e-02, ..., 2.5709e-02,
-1.1299e-02, 4.1590e-03],
[ 7.5425e-02, 3.8779e-02, 5.4930e-02, ..., 4.3596e-02,
1.0225e-02, 1.3251e-02],
...,
[ 9.3155e-02, 1.0374e-01, 6.7547e-02, ..., -2.0277e-01,
-1.2839e-01, -1.1220e-01],
[ 4.3544e-02, 6.4916e-02, 3.6164e-02, ..., -2.0248e-01,
-1.1376e-01, -1.0719e-01],
[ 4.7369e-02, 6.2543e-02, 2.4758e-02, ..., -1.1844e-01,
-9.5567e-02, -8.3890e-02]],
[[-7.2634e-02, -5.7996e-02, -8.0661e-02, ..., -6.0304e-04,
-2.5309e-02, 2.5471e-02],
[-6.9042e-02, -6.7562e-02, -7.6367e-02, ..., -3.9616e-03,
-3.0402e-02, 1.0477e-02],
[-9.9517e-02, -8.5592e-02, -1.0521e-01, ..., -2.6587e-02,
-2.2777e-02, 6.6451e-03],
...,
[-1.5121e-01, -8.8735e-02, -9.6737e-02, ..., 3.0853e-01,
1.8096e-01, 8.4297e-02],
[-1.4309e-01, -7.5710e-02, -7.2215e-02, ..., 2.0417e-01,
1.6447e-01, 9.5166e-02],
[-8.5925e-02, -4.0134e-02, -5.1491e-02, ..., 1.6352e-01,
1.4822e-01, 1.0196e-01]],
[[-2.3596e-02, -2.1258e-03, -2.7761e-02, ..., 3.9940e-02,
-7.1123e-03, 3.2207e-02],
[ 2.5705e-04, 2.2468e-02, 8.9070e-03, ..., 1.8771e-02,
-1.4155e-02, 1.8275e-02],
[ 5.4084e-03, 2.9397e-02, 3.3051e-04, ..., 1.2054e-02,
-2.5237e-03, 8.3515e-03],
...,
[-6.2826e-02, -1.1655e-02, -6.2080e-02, ..., 1.0332e-01,
-9.4987e-03, -7.9570e-02],
[-4.5691e-02, 3.3726e-03, -3.9632e-02, ..., -2.6448e-02,
-3.3500e-02, -7.6398e-02],
[-1.8700e-02, 1.1365e-02, -3.9671e-02, ..., -6.8563e-02,
-4.1289e-02, -5.5473e-02]]],
[[[-1.9950e-03, 2.9262e-03, 4.8212e-02, ..., 6.1402e-02,
2.6121e-02, 1.9558e-02],
[-1.2579e-02, -4.8879e-03, 1.8490e-02, ..., 5.3881e-02,
1.6377e-02, 2.3768e-02],
[ 3.6561e-03, -7.7510e-04, 2.6360e-02, ..., -2.5849e-02,
-6.1798e-02, 2.6103e-02],
...,
[-1.0812e-02, -4.6008e-03, 1.5122e-02, ..., 2.9561e-02,
5.3272e-03, 6.8561e-02],
[ 2.7364e-04, -1.4850e-02, 7.8180e-03, ..., 2.7172e-02,
-1.8095e-02, 5.2485e-02],
[-5.2470e-02, -4.6578e-02, -1.0951e-02, ..., 4.3038e-03,
-2.6379e-03, 1.4406e-02]],
[[ 2.3965e-02, 2.2740e-02, 5.7586e-03, ..., 7.2087e-03,
-2.4652e-02, 4.4658e-02],
[ 2.6914e-02, 4.4892e-02, -1.0872e-03, ..., 4.4243e-02,
-2.1168e-02, 6.4538e-02],
[ 1.2421e-02, 1.0247e-02, -4.1554e-02, ..., -1.2134e-01,
-1.6294e-01, 2.6266e-02],
...,
[ 3.5926e-02, 5.3235e-02, 1.1016e-02, ..., 1.2710e-02,
-2.9737e-02, 8.5926e-02],
[ 1.5623e-02, 2.1743e-02, -8.2941e-03, ..., -3.2744e-03,
-5.4099e-02, 5.7634e-02],
[ 7.5254e-02, 8.7784e-02, 5.5804e-02, ..., 5.2849e-02,
1.0612e-02, 9.3531e-02]],
[[-3.6488e-02, 6.6332e-03, -3.9035e-02, ..., -1.5678e-02,
-7.9994e-02, -8.8658e-04],
[-5.1740e-03, 5.7395e-02, 8.9841e-03, ..., 7.4166e-02,
-3.1792e-03, 4.2777e-02],
[-7.9446e-02, -2.2924e-02, -7.3370e-02, ..., -5.6738e-02,
-1.2923e-01, 1.8896e-02],
...,
[-3.9394e-02, 3.0981e-02, -2.7901e-02, ..., -1.6774e-02,
-1.0236e-01, 4.0128e-02],
[-6.0751e-02, -2.3034e-02, -7.6838e-02, ..., -7.9069e-02,
-1.6195e-01, -1.3746e-02],
[ 7.9522e-03, 4.6969e-02, -1.2460e-02, ..., -4.6956e-02,
-1.0082e-01, 1.9832e-02]]],
[[[-5.1702e-02, 1.3825e-02, 9.0514e-03, ..., -9.6401e-02,
-1.1277e-01, -2.1596e-01],
[-9.0091e-02, -1.3136e-02, -3.2812e-02, ..., -7.5263e-02,
-1.4803e-01, -2.9966e-01],
[-1.3155e-01, -4.2686e-02, -4.7744e-02, ..., 2.1429e-01,
3.2543e-02, -1.7151e-01],
...,
[-1.0621e-01, -9.7966e-02, -2.5551e-01, ..., 1.2277e-01,
1.9287e-01, 1.2671e-01],
[-8.0761e-02, -6.1498e-02, -2.2312e-01, ..., 3.5376e-02,
1.0532e-01, 1.0669e-01],
[ 3.8186e-02, 4.9957e-02, -1.2802e-01, ..., -3.2927e-02,
1.8685e-02, 4.7146e-02]],
[[ 3.9013e-02, 6.4311e-03, -3.1710e-03, ..., -2.1245e-02,
4.0516e-02, 1.1092e-01],
[ 6.5689e-02, 2.2132e-02, 6.6539e-03, ..., -3.9448e-02,
2.7749e-02, 1.1404e-01],
[ 7.7954e-02, 4.0220e-02, 1.4047e-02, ..., -1.5417e-01,
-9.2291e-02, 3.4460e-02],
...,
[ 1.2836e-01, 9.4449e-02, 1.4659e-01, ..., -6.0067e-02,
-9.0891e-02, -6.1129e-02],
[ 1.2683e-01, 1.0044e-01, 1.3754e-01, ..., -2.2507e-02,
-6.6664e-02, -1.9906e-02],
[ 8.0509e-02, 7.8203e-02, 9.8934e-02, ..., 9.2865e-03,
-3.4635e-02, -1.2395e-02]],
[[ 1.1535e-02, -2.6993e-02, 1.4820e-02, ..., 9.4833e-02,
1.2044e-01, 1.1027e-01],
[ 9.2629e-03, -2.6680e-02, 1.2218e-02, ..., 8.7219e-02,
1.5435e-01, 1.8049e-01],
[ 6.9946e-02, 1.3250e-02, 4.8007e-02, ..., -5.6851e-02,
3.2596e-02, 1.6812e-01],
...,
[-1.2187e-02, -3.3265e-02, 1.1284e-01, ..., -6.7740e-02,
-1.0240e-01, -7.6188e-02],
[-6.0069e-03, -2.8631e-02, 1.1643e-01, ..., -6.7597e-03,
-4.3772e-02, -3.1101e-02],
[-1.3355e-01, -1.4825e-01, -1.0060e-03, ..., 1.8809e-02,
-6.4637e-03, -2.7061e-02]]],
...,
[[[ 9.0948e-03, 1.4823e-02, 4.7374e-03, ..., 1.5540e-02,
-5.8369e-04, -1.9922e-02],
[ 2.8962e-04, 2.1229e-02, -1.3210e-02, ..., 2.4388e-03,
-5.8485e-03, -2.0373e-02],
[-1.1050e-02, 1.0094e-02, -2.9625e-02, ..., -1.4471e-02,
-1.7187e-02, -3.0534e-02],
...,
[ 1.0013e-01, 9.1407e-02, 1.3077e-01, ..., 1.5798e-01,
9.0361e-02, 7.8365e-02],
[ 1.1610e-01, 8.1846e-02, 8.2892e-02, ..., -6.0174e-02,
-6.9412e-02, -5.0151e-02],
[-1.0564e-01, -1.1848e-01, -1.7681e-01, ..., -2.0837e-01,
-1.8036e-01, -1.6691e-01]],
[[-1.1495e-02, 2.4917e-03, -8.2400e-03, ..., -7.5865e-03,
-1.7387e-02, -1.7026e-02],
[-2.7461e-03, -1.1415e-02, -6.0670e-03, ..., -2.8248e-02,
-2.2555e-02, -2.2559e-02],
[-8.4246e-03, -2.2378e-03, -3.5825e-02, ..., -1.8462e-02,
-1.9795e-02, -2.4990e-02],
...,
[ 1.2956e-01, 9.8057e-02, 1.4888e-01, ..., 1.5655e-01,
7.9547e-02, 9.6928e-02],
[ 1.6080e-01, 1.0522e-01, 1.0264e-01, ..., -6.5226e-02,
-6.4314e-02, -3.9066e-02],
[-1.2880e-01, -1.4656e-01, -1.9475e-01, ..., -2.4177e-01,
-2.0277e-01, -1.9324e-01]],
[[-5.4209e-03, -1.7648e-03, 4.3163e-03, ..., 9.8182e-03,
5.2347e-03, 6.3336e-03],
[ 9.3478e-03, 1.5920e-03, -2.1660e-03, ..., 7.1244e-03,
9.3521e-04, -5.7647e-03],
[-1.3668e-03, 3.8899e-03, -7.7412e-03, ..., 3.0651e-03,
1.4989e-02, -8.1730e-03],
...,
[ 4.3937e-02, 7.9544e-04, 6.0613e-02, ..., 7.4787e-02,
4.3960e-02, 5.6071e-02],
[ 1.0077e-01, 7.5126e-02, 1.0962e-01, ..., 4.9886e-03,
1.0789e-02, 1.3402e-02],
[-8.8773e-02, -7.2790e-02, -9.2907e-02, ..., -6.6530e-02,
-3.8977e-02, -4.8370e-02]]],
[[[ 4.5712e-03, 4.6908e-02, -1.6075e-02, ..., 7.7790e-03,
-1.9798e-02, 6.8000e-03],
[ 6.2769e-02, 4.5108e-02, 4.7187e-02, ..., 6.0599e-02,
2.9275e-02, 5.5794e-02],
[ 3.6748e-03, 1.2952e-02, 1.8988e-05, ..., -8.3492e-03,
-1.9689e-03, 8.0830e-03],
...,
[-4.4365e-02, -5.8858e-02, -2.4772e-02, ..., -2.8423e-02,
-3.0897e-02, -5.2936e-02],
[-9.7572e-03, -4.3227e-02, 9.0068e-03, ..., -4.2596e-02,
-1.8114e-02, -2.8028e-02],
[-2.2007e-02, -3.3594e-02, 1.3479e-02, ..., -4.1530e-02,
-1.7819e-02, -5.1977e-02]],
[[-9.0596e-02, -5.1485e-02, -1.6459e-01, ..., -1.1970e-01,
-1.1150e-01, -4.3914e-02],
[ 1.3835e-02, 2.6007e-02, -1.9440e-02, ..., 2.3757e-02,
6.3709e-03, 5.4287e-02],
[-9.3225e-02, -4.7454e-02, -1.1274e-01, ..., -8.6459e-02,
-7.3958e-02, -6.6610e-02],
...,
[ 2.7382e-02, 1.0310e-02, 4.3906e-02, ..., 2.7094e-02,
4.4510e-02, 1.5977e-02],
[ 9.8431e-02, 6.1433e-02, 1.1413e-01, ..., 9.6398e-02,
1.0725e-01, 9.5719e-02],
[-1.5100e-02, -1.1830e-02, 4.8571e-02, ..., 2.9060e-02,
5.6323e-02, -2.1631e-03]],
[[-1.3693e-01, -7.9341e-02, -2.1245e-01, ..., -1.3633e-01,
-1.5123e-01, -6.3938e-02],
[ 1.5745e-02, 5.1443e-02, -1.8209e-02, ..., 4.9085e-02,
1.9585e-02, 7.8095e-02],
[-1.7127e-01, -8.8734e-02, -1.7467e-01, ..., -1.4431e-01,
-1.3364e-01, -1.1878e-01],
...,
[ 5.1982e-02, 1.5912e-02, 7.0009e-02, ..., 4.4396e-02,
6.5429e-02, 2.6919e-02],
[ 1.3144e-01, 9.1333e-02, 1.6228e-01, ..., 1.4230e-01,
1.5500e-01, 1.3808e-01],
[ 3.7818e-03, -2.0365e-02, 7.8663e-02, ..., 8.6037e-02,
1.2596e-01, 3.8774e-02]]],
[[[-9.5081e-02, 5.6581e-02, 1.4016e-01, ..., -1.4826e-02,
1.0425e-02, -3.7919e-03],
[ 6.0412e-02, 1.2531e-01, -1.2366e-01, ..., 7.8628e-02,
-1.0564e-02, -2.2228e-02],
[ 1.0304e-01, -1.3654e-01, -1.9601e-01, ..., -5.1919e-02,
-8.3287e-02, 3.9678e-02],
...,
[-1.8598e-01, 1.5479e-02, 3.3912e-01, ..., 2.7509e-01,
1.0493e-01, -1.6863e-01],
[ 6.8278e-02, 1.3989e-01, 1.1350e-02, ..., -7.8728e-02,
-2.5001e-01, -8.0072e-02],
[ 5.1736e-03, -6.8372e-02, -9.1940e-02, ..., -1.0727e-01,
8.7973e-02, 1.0139e-01]],
[[-9.8760e-02, 6.2635e-02, 1.1495e-01, ..., -1.3191e-02,
1.9574e-02, -1.0864e-03],
[ 8.3206e-02, 1.0300e-01, -1.5419e-01, ..., 1.0864e-01,
1.2592e-02, -2.0278e-02],
[ 1.1248e-01, -1.6928e-01, -1.8876e-01, ..., -6.3491e-02,
-9.3101e-02, 4.9006e-02],
...,
[-1.8867e-01, 6.7374e-02, 4.8128e-01, ..., 3.1371e-01,
1.6068e-01, -1.3449e-01],
[ 1.1207e-01, 2.0871e-01, 4.6815e-02, ..., -7.0456e-03,
-2.2978e-01, -6.9528e-02],
[ 1.7255e-02, -9.1683e-02, -1.5943e-01, ..., -8.0373e-02,
8.0035e-02, 1.1819e-01]],
[[-1.0580e-01, 5.4381e-02, 1.3048e-01, ..., -3.9546e-02,
1.1924e-02, -2.7886e-03],
[ 5.7859e-02, 1.0646e-01, -1.3468e-01, ..., 1.0645e-01,
1.8412e-02, -1.3245e-03],
[ 1.0398e-01, -1.0849e-01, -1.6758e-01, ..., -4.2554e-02,
-8.5875e-02, 5.2581e-02],
...,
[-1.8241e-01, 5.4228e-02, 3.9395e-01, ..., 2.4432e-01,
1.0216e-01, -1.2876e-01],
[ 8.5720e-02, 1.8378e-01, 5.0316e-02, ..., -4.7009e-02,
-2.1584e-01, -4.2482e-02],
[ 3.9612e-02, -7.6353e-02, -1.3502e-01, ..., -4.8630e-02,
1.0063e-01, 9.0307e-02]]]], device='cuda:0')
第二步:
特征图的每个像素点上的数值范围不是[0,1],而是可正可负,可大可小,因此需要做一些特殊处理。这里就要用到 torchvision.utils.make_grid( )函数,把输入的特征图做一个归一化,把参数normalize设置为True即可,它能帮我们把数据的输入范围调整至[0, 1]之间
- make_grid() 输入的是Tensor,而不是numpy.ndarray
- torchvision.utils.make_grid() 将一组图片绘制到一个窗口,其本质是将一组图片拼接成一张图片
conv1_images = make_grid(conv1_weights, normalize=True).cpu()
print(conv1_images)
tensor([[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.5244, ..., 0.4508, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.3961, ..., 0.5143, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.4130, ..., 0.4338, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.4496, ..., 0.5241, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.4416, ..., 0.3990, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.4963, ..., 0.5079, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])
最后
imshow()的输入参数:
真彩色图像输入,指定为 m×n×3 数组。如果指定数据类型为 single 或 double 的真彩色图像,则值应在 [0, 1] 范围内。如果像素值超出此范围,则可以使用rescale函数将像素值缩放到范围 [0, 1] 内。当输入图像为真彩色时,'DisplayRange' 参数无效。
数据类型: single
| double
| int8
| int16
| int32
| int64
| uint8
| uint16
| uint32
| uint64
| logical
plt.figure(figsize=(8, 8))
plt.imshow(conv1_images.permute(1, 2, 0).numpy())
输出
PyTorch 实现的 AlexNet 第一层卷积核参数的形状是 64×3×11×1164×3×11×11 的四维 Tensor,这样就可以得到上述 6464 个11×1111×11 的图片块了。显然,这些重构出来的图像基本都是关于边缘,条纹以及颜色的信息
以上为卷积核的可视化接下来看看卷积层的可视化:
def visualize(alexnet, input_data, submodule_name, layer_index):
'''
alexnet: 模型
input_data: 输入数据
submodule_name: 可视化 module 的 name, 专门针对 nn.Sequential
layer_index: 在 submodule 中的 index
'''
x = input_data
modules = alexnet._modules
for name in modules.keys():
if (name == submodule_name):
module_layers = list(modules[name].children())
for i in range(layer_index+1):
if (type(module_layers[i]) == torch.nn.Linear):
x = x.reshape(x.size(0), -1) # 针对线性层
x = module_layers[i](x)
return x
x = modules[name](x)
feature_maps = visualize(model_saved, IMAGE.to(dev), 'features', 0)
feature_images = make_grid(feature_maps.permute(1, 0, 2, 3), normalize=True).cpu()
plt.figure(figsize=(8, 8))
plt.imshow(feature_images.permute(1, 2, 0).numpy())
输出可视化的卷积层: