import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
class AlexNet(nn.Module):
    """Convolutional trunk of a half-width AlexNet that exposes per-conv feature maps.

    Only the ``features`` stack is built: five ``Conv2d`` layers with
    48/128/192/192/128 output channels, each followed by an in-place ReLU, plus
    three max-pool layers. There is no classifier head, so :meth:`forward`
    returns intermediate activations rather than class scores.

    Args:
        num_classes: Accepted for signature compatibility with a classifying
            AlexNet; unused here because no classifier layers are built.
        init_weights: If True, apply :meth:`initialize_weights` after the layers
            are created. Defaults to False, which keeps PyTorch's default layer
            initialisation (the previous behaviour).
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        if init_weights:
            self.initialize_weights()

    def forward(self, x):
        """Run ``x`` through ``features`` and collect the output of every Conv2d.

        Args:
            x: Input batch; the first layer expects 3 input channels
                (presumably an (N, 3, H, W) image tensor).

        Returns:
            list[torch.Tensor]: One tensor per ``Conv2d`` layer, in network
            order (five in total).

        Note:
            Each ``Conv2d`` is immediately followed by ``nn.ReLU(inplace=True)``,
            which overwrites the very tensor appended here. The recorded maps
            therefore hold post-ReLU (non-negative) activations, not raw
            convolution outputs.
        """
        feature_maps = []
        for layer in self.features:
            x = layer(x)
            if isinstance(layer, nn.Conv2d):
                feature_maps.append(x)
        return feature_maps

    def initialize_weights(self):
        """Re-initialise the weights of all Conv2d and Linear submodules in place.

        Conv2d layers use Kaiming-normal weights (``fan_out``, ReLU gain), and
        Linear layers use N(0, 0.01) weights. Biases, when present, are zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                # Linear layers may be created with bias=False.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
def visualize_features(img, model, save_dir, n_cols=16):
    """Save one PNG grid per feature-map tensor returned by ``model(img)``.

    For the k-th tensor returned by the model (k counted from 1), every channel
    of the first batch element is drawn as its own tile. The grid is written to
    ``<save_dir>/feature_map_layer_<k>.png``.

    Args:
        img: Input passed straight to ``model``.
        model: Callable returning a sequence of tensors indexed as
            ``fm[0, channel]``, e.g. ``AlexNet`` in this file (presumably each
            tensor is (N, C, H, W) — confirm for other models).
        save_dir: Existing directory that the PNG files are written into.
        n_cols: Tiles per grid row. The row count is derived from the channel
            count, so every channel is drawn. The default of 16 reproduces the
            previous 3/8/12/12/8-row layouts for 48/128/192/192/128 channels.

    Raises:
        ValueError: If ``n_cols`` is smaller than 1.
    """
    if n_cols < 1:
        raise ValueError("n_cols must be >= 1")

    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        feature_maps = model(img)

    for idx, fm in enumerate(feature_maps, 1):
        channels = fm.size(1)
        # Ceiling division so that channels beyond a full row still get tiles.
        n_rows = max(1, (channels + n_cols - 1) // n_cols)
        # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid.
        figure, axs = plt.subplots(
            n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), squeeze=False
        )
        for i, ax in enumerate(axs.flat):
            if i < channels:
                ax.imshow(fm[0, i].detach().cpu().numpy(), cmap='viridis')
            # Hide ticks/frames on every tile, including unused trailing ones.
            ax.axis('off')
        figure.subplots_adjust(wspace=0.02, hspace=0.02)
        figure.savefig(
            os.path.join(save_dir, f'feature_map_layer_{idx}.png'),
            bbox_inches='tight',
            pad_inches=0,
        )
        # Release the figure so memory doesn't grow with each saved layer.
        plt.close(figure)
if __name__ == "__main__":
    import sys

    # Default sample image; pass a different path as the first CLI argument.
    img_path = r"E:\Python_Project\deep-learning-for-image-processing-master" \
               r"\data_set\flower_data\train\dandelion\15987457_49dc11bf4b.jpg"
    if len(sys.argv) > 1:
        img_path = sys.argv[1]

    # convert('RGB') guarantees the 3 channels the first Conv2d expects
    # (grayscale/RGBA/palette images would otherwise break the forward pass).
    # The context manager closes the underlying file handle.
    with Image.open(img_path) as raw_img:
        rgb_img = raw_img.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
    ])
    img = transform(rgb_img).unsqueeze(0)  # add batch dim -> (1, 3, 227, 227)

    save_dir = 'feature_maps'
    os.makedirs(save_dir, exist_ok=True)

    # NOTE(review): no trained weights are loaded, so the maps reflect
    # PyTorch's default random initialisation. Load a state_dict here if
    # learned features are wanted.
    model = AlexNet()
    model.eval()
    visualize_features(img, model, save_dir)