根据大神的代码整理的。
链接在这
自己定义了一个LeNet,
代码如下:
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
class LeNet(nn.Module):
'''
该类继承了torch.nn.Modul类
构建LeNet神经网络模型
'''
def __init__(self):
super(LeNet, self).__init__()
# 第一层神经网络,包括卷积层、线性激活函数、池化层
self.conv1 = nn.Sequential(
nn.Conv2d(3, 8, 5, 1, 2), # input_size=(3*256*256),padding=2
nn.ReLU(), # input_size=(32*256*256)
nn.MaxPool2d(kernel_size=2, stride=2), # output_size=(32*128*128)
)
# 第二层神经网络,包括卷积层、线性激活函数、池化层
self.conv2 = nn.Sequential(
nn.Conv2d(8, 16, 5, 1, 2), # input_size=(32*128*128)
nn.ReLU(), # input_size=(64*128*128)
nn.MaxPool2d(2, 2) # output_size=(64*64*64)
)
# 全连接层(将神经网络的神经元的多维输出转化为一维)
self.fc1 = nn.Sequential(
nn.Linear(16 * 64 * 64, 128), # 进行线性变换
nn.ReLU() # 进行ReLu激活
)
# 输出层(将全连接层的一维输出进行处理)
self.fc2 = nn.Sequential(
nn.Linear(128, 84),
nn.ReLU()
)
# 将输出层的数据进行分类(输出预测值)
self.fc3 = nn.Linear(84, 62)
# 定义前向传播过程,输入为x
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
# nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
x = x.view(x.size()[0], -1)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
def get_picture(picture_dir, transform):
'''
该算法实现了读取图片,并将其类型转化为Tensor
'''
img = skimage.io.imread(picture_dir)
img256 = skimage.transform.resize(img, (256, 256))
img256 = np.asarray(img256)
img256 = img256.astype(np.float32)
return transform(img256)
def get_picture_rgb(picture_dir):
'''
该函数实现了显示图片的RGB三通道颜色
'''
img = skimage.io.imread(picture_dir)
img256 = skimage.transform.resize(img, (256, 256))
skimage.io.imsave('D:\code\jupyter\data/4.jpg',img256)
img = img256.copy()
ax = plt.subplot()
ax.set_title('image')
plt.imshow(img)
plt.show()
class FeatureExtractor(nn.Module):
def __init__(self, submodule, extracted_layers):
super(FeatureExtractor, self).__init__()
self.submodule = submodule
self.extracted_layers = extracted_layers
def forward(self, x):
outputs = []
# print(self.submodule._modules.items())
for name, module in self.submodule._modules.items():
if "fc" in name:
# print(name)
x = x.view(x.size(0), -1)
# print(module)
x = module(x)
# print(name)
if name in self.extracted_layers:
outputs.append(x)
return outputs
def get_feature():
# 输入数据
img = get_picture(pic_dir, transform)
# 插入维度
img = img.unsqueeze(0)
img = img.to(device)
# 特征输出
net = LeNet().to(device)
# net.load_state_dict(torch.load('./model/net_050.pth'))
exact_list = ["conv1","conv2"]
myexactor = FeatureExtractor(net, exact_list)
x = myexactor(img)
# 特征输出可视化
for i in range(8):
ax = plt.subplot(2,4, i + 1)
ax.set_title('Feature {}'.format(i))
ax.axis('off')
plt.imshow(x[0].cpu()[0,i,:,:].detach().numpy(),cmap='jet')
# plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
plt.show()