import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
# AlexNet-style convolutional feature extractor used for feature-map visualisation.
class AlexNet(nn.Module):
    """Convolutional trunk of an AlexNet-style network (no classifier head).

    Only the ``features`` stack is built. ``num_classes`` is accepted so the
    constructor signature matches a full AlexNet, but it is not used anywhere.
    ``forward`` returns the intermediate maps produced by each ``Conv2d`` layer.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # The order of this list fixes the Sequential indices (0..12) that show
        # up in state_dict keys such as "features.0.weight"; keep it unchanged.
        layers = [
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through ``features`` and collect every Conv2d output.

        Returns a list with one tensor per Conv2d layer (five here), in layer
        order. Each Conv2d is immediately followed by ``ReLU(inplace=True)``,
        which overwrites the very tensor that was just collected, so the
        returned maps hold post-ReLU (non-negative) values rather than raw
        convolution outputs.
        """
        collected = []
        out = x
        for module in self.features:
            out = module(out)
            if isinstance(module, nn.Conv2d):
                collected.append(out)
        return collected

    def initialize_weights(self):
        """Re-initialise Conv2d and Linear layers in place.

        Conv2d: Kaiming-normal weights (mode='fan_out', ReLU gain), zero bias.
        Linear: normal weights with mean 0 / std 0.01, zero bias. This model
        defines no Linear layers, so that branch is currently inert.

        This method is not called anywhere in this file. Note also that nothing
        here loads pretrained weights, so unless this method (or a
        load_state_dict) is called, the layers keep PyTorch's default
        initialisation.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
# Feature-map visualisation
def visualize_features(img, model, save_dir, n_cols=16):
    """Save one grid image per feature map returned by ``model(img)``.

    ``model(img)`` is expected to return a sequence of tensors indexed as
    ``fm[0, channel]``. Batch element 0 of each tensor is drawn, one channel
    per subplot (presumably (N, C, H, W) maps, as ``AlexNet.forward`` in this
    file produces; verify for other models). The grid has ``n_cols`` columns
    and as many rows as needed to hold every channel. For the AlexNet in this
    file (48/128/192/192/128 channels, 16 columns) this reproduces the
    previously hard-coded 3/8/12/12/8-row layouts.

    Args:
        img: input passed unchanged to ``model``.
        model: callable returning the list of feature-map tensors.
        save_dir: output directory for the PNG files (created if missing).
        n_cols: number of subplot columns in each grid (default 16).

    Returns:
        List of written file paths, named ``feature_map_layer_<k>.png`` with
        ``k`` starting at 1.

    Raises:
        ValueError: if ``n_cols`` is smaller than 1.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")
    os.makedirs(save_dir, exist_ok=True)

    # Inference only: don't build an autograd graph for the whole network.
    with torch.no_grad():
        feature_maps = model(img)

    saved_paths = []
    for idx, fm in enumerate(feature_maps, 1):
        # Convert batch element 0 to NumPy once, not once per channel.
        channels = fm[0].detach().cpu().numpy()
        n_channels = channels.shape[0]
        # Ceiling division; at least one row so plt.subplots never gets 0.
        n_rows = max(1, (n_channels + n_cols - 1) // n_cols)
        # squeeze=False keeps axs 2-D even for a 1x1 grid, so .flat always works.
        figure, axs = plt.subplots(
            n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), squeeze=False
        )
        try:
            for i, ax in enumerate(axs.flat):
                if i < n_channels:  # trailing cells in the last row stay blank
                    ax.imshow(channels[i], cmap='viridis')
                ax.axis('off')
            figure.subplots_adjust(wspace=0.02, hspace=0.02)
            path = os.path.join(save_dir, f'feature_map_layer_{idx}.png')
            figure.savefig(path, bbox_inches='tight', pad_inches=0)
        finally:
            # Always release the figure, even if drawing/saving fails.
            plt.close(figure)
        saved_paths.append(path)
    return saved_paths
# Load the input image and preprocess it.
img_path = r"E:\Python_Project\deep-learning-for-image-processing-master" \
           r"\data_set\flower_data\train\dandelion\15987457_49dc11bf4b.jpg"
transform = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.ToTensor(),
])
# convert("RGB") guarantees 3 channels: grayscale, RGBA or palette images would
# otherwise give 1 or 4 channels and break the first Conv2d (3 input channels).
# The with-block closes the file handle opened by Image.open.
with Image.open(img_path) as pil_img:
    img = transform(pil_img.convert("RGB")).unsqueeze(0)  # add leading batch dim

# Make sure the output directory for the feature-map images exists
# (exist_ok avoids a separate, race-prone exists() check).
save_dir = 'feature_maps'
os.makedirs(save_dir, exist_ok=True)

# Create the AlexNet model.
# NOTE(review): no pretrained weights are loaded here, so the layers keep their
# default random initialisation and the maps show an untrained network. Load a
# state_dict (model.load_state_dict(...)) first if trained filters are intended.
model = AlexNet()
model.eval()  # evaluation mode

# Visualise the feature maps; no autograd graph is needed for inference.
with torch.no_grad():
    visualize_features(img, model, save_dir)
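# NOTE(review): the lines below are stray text apparently pasted from a web
# page (dates / view counts), not code; kept as comments so the file parses.
# 10-19
# 1333
# 01-19
# 6万+
# 05-21
# 1万+
# 06-07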