# 提取图片的特征图
class FeatureExtractor(nn.Module):
    """Run ``submodule``'s direct children in order and collect their outputs.

    The children are applied one after another, which matches a plain
    feed-forward model such as the resnet18 used by ``get_feature``.  Before
    any child whose name contains ``'fc'`` the running tensor is flattened to
    ``(batch, -1)``, since a classifier layer needs a flat input.

    Args:
        submodule: Model whose registered children are executed in order.
        extracted_layers: Names of the children whose outputs to return, or
            ``None`` to return every child's output.  Children whose name
            contains ``'fc'`` are never returned in either case, because their
            output is not a spatial feature map.
    """

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Return ``{child_name: output}`` for the selected children.

        Assumes ``x`` is a batched tensor accepted by the first child —
        TODO confirm for models other than resnet18.
        """
        outputs = {}
        for name, module in self.submodule.named_children():
            is_fc = 'fc' in name
            if is_fc:
                x = x.flatten(1)
            x = module(x)
            # The 'fc' exclusion must apply whether or not extracted_layers is
            # None, hence the explicit grouping of the selection test.
            selected = self.extracted_layers is None or name in self.extracted_layers
            if selected and not is_fc:
                outputs[name] = x
        return outputs
def get_picture(pic_name, transform, size=(224, 224)):
    """Load an image file as 3-channel RGB, resize it, and apply ``transform``.

    Args:
        pic_name: Path of the image file, read with ``skimage.io.imread``.
        transform: Callable applied to the resized float32 H x W x 3 array
            (``transforms.ToTensor()`` in this file's ``get_feature``).
        size: Output ``(height, width)``; defaults to the original 224 x 224.

    Returns:
        Whatever ``transform`` returns for the prepared array.
    """
    img = skimage.io.imread(pic_name)
    # The ResNet stem takes exactly 3 input channels, but grayscale files
    # decode as H x W, gray+alpha as H x W x 2 and RGBA as H x W x 4.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[-1] == 2:
        img = np.stack([img[..., 0]] * 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]
    # resize() returns a float image; with its default preserve_range=False,
    # integer input is rescaled to [0, 1].
    img = skimage.transform.resize(img, size)
    img = np.asarray(img, dtype=np.float32)
    return transform(img)
def make_dirs(path):
    """Create directory ``path`` and any missing parents; no-op if it exists.

    ``exist_ok=True`` makes this idempotent without the race of checking
    ``os.path.exists`` first.  If ``path`` exists but is not a directory,
    ``FileExistsError`` is raised immediately, rather than a later write into
    the "directory" failing.
    """
    os.makedirs(path, exist_ok=True)
def _feature_map_to_uint8(fmap):
    """Min-max scale one 2-D feature map to the full uint8 range [0, 255].

    Activations are unbounded floats (they can be negative or far above 1), so
    a plain ``np.asarray(fmap * 255, dtype=np.uint8)`` casts out-of-range
    values, whose result NumPy leaves undefined.  A constant (or NaN) map
    becomes all zeros.
    """
    fmap = np.asarray(fmap, dtype=np.float32)
    lo = float(fmap.min())
    span = float(fmap.max()) - lo
    if not span > 0:
        return np.zeros(fmap.shape, dtype=np.uint8)
    scaled = (fmap - lo) * (255.0 / span)
    return np.clip(np.round(scaled), 0, 255).astype(np.uint8)


def get_feature(pic_dir='./ww.jpg', weights_path='./resnet18-f37072fd.pth',
                dst='./ww', extracted_layers=None, upscale_size=256,
                save_upscaled=False):
    """Extract resnet18 feature maps for one image and save them as JET images.

    For every layer returned by ``FeatureExtractor`` (``fc`` is skipped), each
    channel of the first batch item is min-max scaled to uint8, coloured with
    ``cv2.COLORMAP_JET`` and written to ``<dst>/<layer>/<channel>.png``.
    Every argument defaults to the value previously hard-coded here.

    Args:
        pic_dir: Path of the input image.
        weights_path: resnet18 ``state_dict`` checkpoint (presumably
            torchvision's ImageNet weights, judging by the file name).
        dst: Output root; one sub-directory per layer is created under it.
        extracted_layers: Layer names to save, or ``None`` for all layers.
        upscale_size: Threshold and target side length for enlarged copies.
        save_upscaled: If true, maps whose height is below ``upscale_size``
            are also saved enlarged (nearest neighbour) as
            ``<channel>_<upscale_size>.png``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): pretrained torchvision ResNets are normally fed ImageNet
    # mean/std-normalised input; ToTensor() alone does not do that. Confirm
    # whether un-normalised input is intended.
    transform = transforms.ToTensor()
    # unsqueeze(0) adds the batch dimension the model expects.
    img = get_picture(pic_dir, transform).unsqueeze(0).to(device)

    net = models.resnet18().to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    net.load_state_dict(torch.load(weights_path, map_location=device))
    # Inference mode: BatchNorm must use its stored running statistics rather
    # than this single image's, and must not overwrite them.
    net.eval()

    extractor = FeatureExtractor(net, extracted_layers)
    with torch.no_grad():  # no autograd graph is needed for visualisation
        outs = extractor(img)

    for layer_name, activation in outs.items():
        # Only (batch, C, H, W) outputs are images; 'fc' yields a flat vector.
        if 'fc' in layer_name or activation.dim() != 4:
            continue
        # .cpu() is required before .numpy() when the model runs on CUDA.
        maps = activation[0].cpu().numpy()
        layer_dir = os.path.join(dst, layer_name)
        make_dirs(layer_dir)
        for i, fmap in enumerate(maps):
            colored = cv2.applyColorMap(_feature_map_to_uint8(fmap), cv2.COLORMAP_JET)
            cv2.imwrite(os.path.join(layer_dir, f'{i}.png'), colored)
            if save_upscaled and colored.shape[0] < upscale_size:
                enlarged = cv2.resize(colored, (upscale_size, upscale_size),
                                      interpolation=cv2.INTER_NEAREST)
                cv2.imwrite(os.path.join(layer_dir, f'{i}_{upscale_size}.png'), enlarged)
if __name__ == '__main__':
    # Script entry point: run the feature-map extraction and export.
    get_feature()
# 输入图片（./ww.jpg）
# 结果（特征图保存在 ./ww/<层名>/<通道序号>.png）
# 此为 bn1 层的一张特征图