具体代码如下:
import torch import torch.nn as nn import numpy as np import matplotlib.pyplot as plt from PIL import Image import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" myim = Image.open("data/Lenna.png") myimgray = np.array(myim.convert("L"),dtype=np.float32) plt.figure(figsize=(6,6)) plt.imshow(myimgray,cmap=plt.cm.gray) plt.axis("off") plt.show() imh , imw = myimgray.shape myimgray_t = torch.from_numpy(myimgray.reshape((1,1,imh,imw))) print("数组转张量后的shape:",myimgray_t.shape) kersize = 5 #定义边缘检测卷积核,并将维度处理为1*1*5*5 ker = torch.ones(kersize, kersize, dtype=torch.float32) * -1 ker[2,2] = 24 ker = ker.reshape((1,1,kersize,kersize)) ##进行卷积操作 conv2d = nn.Conv2d(1,2,(kersize,kersize),bias = False) ##设置卷积时使用的核,第一个核使用边缘检测核 conv2d.weight.data[0] = ker ##对灰度图进行卷积操作 imconv2dout = conv2d(myimgray_t) ##对卷积后的输出进行维度压缩 imconv2dout_im = imconv2dout.data.squeeze() print("卷积后尺寸:",imconv2dout_im.shape) ##可视化卷积后的图像 plt.figure(figsize=(12,6)) plt.subplot(1,2,1) plt.imshow(imconv2dout_im[0],cmap=plt.cm.gray) plt.axis("off") plt.subplot(1,2,2) plt.imshow(imconv2dout_im[1],cmap=plt.cm.gray) plt.axis("off") plt.show()
使用边缘特征提取卷积核很好的提取出了图像的边缘信息;右图使用的卷积核为随机数,得到的卷积结果与原图很相似。