自定义卷积核——pytorch
自定义一些常见的滤波器卷积核,定义成类,可以在网络框架中直接调用
定义类,写入自定义卷积核,用于网络处理图片
高斯滤波器写入网络框架
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import cv2
import matplotlib.pyplot as plt
class Gaussfilter(nn.Module):
def __init__(self, channels=1, kernel_size=3, sigma=1.5):
super(Gaussfilter, self).__init__()
self.channels = channels
self.k_size = kernel_size
self.sigma = sigma
x_data, y_data = np.mgrid[-(self.k_size) // 2 + 1:(self.k_size) // 2 + 1, -(self.k_size) // 2 + 1:(self.k_size) // 2 + 1]
x_ = torch.FloatTensor(x_data.astype(np.float32)).unsqueeze(0).unsqueeze(0)
y_ = torch.FloatTensor(y_data.astype(np.float32)).unsqueeze(0).unsqueeze(0)
g = torch.exp(-((x_ ** 2 + y_ ** 2) / torch.tensor(2.0 * self.sigma ** 2)))
g = g / torch.sum(g)
kernel = torch.FloatTensor(g)
# kernel = np.repeat(kernel, self.channels, axis=0)
self.weight = nn.Parameter(data=kernel, requires_grad=False)
def __call__(self, x): # 输入的X应该维度增加过.unsqueeze(0).unsqueeze(0)
x = x.unsqueeze(0).unsqueeze(0)
x = F.conv2d(x, self.weight, stride=1, padding=1, groups=self.channels)
return x
定义类,调用类测试自定义高斯滤波器卷积核,实现将图像高低频信息分离。方便网络单独处理单频信息。
"""
test
"""
input_x = plt.imread('F:/BGR_demo/monarch_gray.jpg')
# cv2.imshow("input_x", input_x)
input_x = Variable(torch.from_numpy(input_x.astype(np.float32)/255))
gaussian_conv = Gaussfilter(1, 3, 1.5)
out_x = gaussian_conv(input_x)
img_H = input_x - out_x
out_x = (out_x.squeeze(0).squeeze(0)*255).data.numpy().astype(np.uint8)
img_H = (img_H.squeeze(0).squeeze(0)*255).data.numpy().astype(np.uint8)
# cv2.imshow("out_x", out_x)
# cv2.waitKey(0)
cv2.imshow("img_H", img_H)
cv2.waitKey(0)