一、离散小波变换(DWT+IDWT)
1、实验过程
为了更好的学习下采样DWT和上采样IDWT的过程,我们构建一个简单的网络模型,将一张图片先进行下采样,然后再进行上采样操作,最后我们比较输入和输出图片的是否完全一样?PSNR值!
所以这个操作是可逆的!一张图片先经过下采样然后再进行上采样,是可以完全恢复的!
理解这个F.conv_transpose2d函数
import torch
inputs = torch.randint(1,10, (1, 4, 1, 1))
print(inputs)
print(inputs.shape)
weights = torch.randint(1,10,(4, 1, 2, 2))
print(weights)
print(weights.shape)
y = F.conv_transpose2d(inputs, weights, groups=1, stride=2)
print(y)
print(y.shape)
inputs = torch.randint(1,10, (1, 4, 1, 1))
print(inputs)
print(inputs.shape)
weights = torch.randint(1,10,(1, 4, 2, 2))
print(weights)
print(weights.shape)
y1=F.conv2d(inputs, weights, padding=1,stride=1)
print(y1)
print(y1.shape)
2、代码
# -*- coding: utf-8 -*-
import numpy as np
import torch.nn.init as init
import torch.nn.functional as F
import torch
from torch import nn
import cv2
#离散小波变换 下采样操作
class DWTForward(nn.Module):
def __init__(self):
super(DWTForward, self).__init__()
#ll lh hl hh shape(2,2)
ll = np.array([[0.5, 0.5], [0.5, 0.5]])
lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
#filts shape(4,1,2,2)
filts = np.stack([ll[None,::-1,::-1], lh[None,::-1,::-1],
hl[None,::-1,::-1], hh[None,::-1,::-1]],
axis=0)
self.weight = nn.Parameter(
torch.tensor(filts).to(torch.get_default_dtype()),
requires_grad=False)
def forward(self, x):
print("输入灰度图的形状:", x.shape) #torch.Size([1, 1, 512, 512])
C = x.shape[1] ##通道数
print("生成下采样参数的形状:",self.weight.shape) # torch.Size([4, 1, 2, 2])
#如果通道数是C,那们filters:torch.Size([4C, 1, 2, 2])
filters = torch.cat([self.weight,] * C, dim=0)
print("生成下采样卷积核的形状:", filters.shape) # torch.Size([4, 1, 2, 2])
##其实这里进行分组卷积操作 由定义的卷积进行下采样操作
#输入X:(1,1,512,512) filters:(4,1,2,2) group=1,stride=2 padding = 1
#输出Y:(1,4,256,256)
y = F.conv2d(x, filters, groups=