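# PANet: image super-resolution reconstruction based on a pyramid attention network
# (original title: PANet:基于金字塔注意力网络的图像超分辨率重建)
# This file contains the complete code only; see the original article for the explanation.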
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset,SubsetRandomSampler
import torch.optim as optim
from torchvision.utils import save_image
import os
import cv2
import random as ra
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract sliding-window patches from images and put them in the C output dimension.

    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
        each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :param padding: 'same' pads the input with same_padding() before
        unfolding; 'valid' unfolds the input as given
    :return: A Tensor of shape [N, C*k*k, L], L is the total number of patches
    :raises AssertionError: if images is not 4-D or padding is not one of
        'same' / 'valid'
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    # The assertion above already rejects every other padding value, so only
    # the 'same' case needs any work here; 'valid' is a no-op.
    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    return unfold(images)  # [N, C*k*k, L], L is the total number of such blocks
def reduce_sum(x, axis=None, keepdim=False):
    """
    Sum a tensor over one or several dimensions.

    :param x: input tensor
    :param axis: None (or an empty sequence) to reduce over every dimension,
        a single int, or an iterable of ints; negative indices are allowed
    :param keepdim: keep the reduced dimensions with size 1
    :return: the reduced tensor
    """
    ndim = x.dim()
    if axis is None:
        axes = range(ndim)
    elif isinstance(axis, int):
        # A bare int (including 0) selects exactly that dimension.
        axes = [axis]
    else:
        # An empty sequence keeps its historical meaning: reduce everything.
        axes = list(axis) or range(ndim)
    # Normalise negative indices and drop duplicates, then reduce from the
    # highest dimension down so the remaining indices stay valid when
    # keepdim is False.
    dims = sorted({d + ndim if d < 0 else d for d in axes}, reverse=True)
    for dim in dims:
        x = torch.sum(x, dim=dim, keepdim=keepdim)
    return x
def same_padding(images, ksizes, strides, rates):
    """
    Zero-pad a 4-D image batch so that a sliding window with the given kernel
    size, stride and dilation yields ceil(size / stride) positions per axis.

    :param images: [batch, channels, rows, cols] tensor
    :param ksizes: [ksize_rows, ksize_cols]
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: the padded tensor; when the total padding along an axis is odd,
        the extra row/column goes to the bottom/right
    """
    assert len(images.size()) == 4
    rows, cols = images.size()[2:]

    def _total_pad(extent, ksize, stride, rate):
        # Total padding along one axis for the target output length.
        out_extent = (extent + stride - 1) // stride
        effective_k = (ksize - 1) * rate + 1
        return max(0, (out_extent - 1) * stride + effective_k - extent)

    pad_rows = _total_pad(rows, ksizes[0], strides[0], rates[0])
    pad_cols = _total_pad(cols, ksizes[1], strides[1], rates[1])

    # Split each total; the leading side gets the truncated half.
    top = int(pad_rows / 2.)
    left = int(pad_cols / 2.)
    bottom = pad_rows - top
    right = pad_cols - left

    pad_layer = torch.nn.ZeroPad2d((left, right, top, bottom))
    return pad_layer(images)
def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):
    """
    Build a 2-D convolution padded by half the kernel size.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: integer kernel size
    :param stride: convolution stride
    :param bias: whether the layer has a bias term
    :return: the configured nn.Conv2d
    """
    half_kernel = kernel_size // 2
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )
    return layer
# Sentinel for "caller did not pass an activation". A module instance used as
# a default argument would be created once and shared (weights included) by
# every BasicBlock relying on the default.
_DEFAULT_ACT = object()


class BasicBlock(nn.Sequential):
    """
    Convolution followed by optional batch norm and optional activation.

    :param conv: factory called as conv(in_channels, out_channels,
        kernel_size, bias=..., [stride=...]) returning the convolution layer
    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: kernel size forwarded to conv
    :param stride: stride forwarded to conv when it differs from 1
    :param bias: whether the convolution has a bias term
    :param bn: append nn.BatchNorm2d(out_channels) after the convolution
    :param act: activation module appended last; None disables it; when
        omitted, a fresh nn.PReLU() is created for this block
    """

    def __init__(
            self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
            bn=False, act=_DEFAULT_ACT):
        if act is _DEFAULT_ACT:
            act = nn.PReLU()

        conv_kwargs = {'bias': bias}
        # Forward stride only when it is non-default, so conv factories that
        # do not accept a stride keyword keep working.
        if stride != 1:
            conv_kwargs['stride'] = stride

        m = [conv(in_channels, out_channels, kernel_size, **conv_kwargs)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)
class PyramidAttention(nn.Module):
    """
    Attention block that matches every location of the input feature map
    against patches extracted from several rescaled copies of that same input,
    then rebuilds the output from the matched patches and adds a residual.

    :param level: number of pyramid scales; scale i is 1 - i/10
    :param res_scale: multiplier applied to the input before it is added to
        the attention output
    :param channel: number of channels of the input feature map
    :param reduction: the matching convolutions output channel // reduction
        channels
    :param ksize: side length of the square patches
    :param stride: stride used for patch extraction and for the transposed
        convolution that pastes patches back
    :param softmax_scale: factor applied to the matching scores before softmax
    :param average: when False, the softmax weights are replaced by a one-hot
        mask of the per-location maximum
    :param conv: convolution factory handed to BasicBlock
    """
    def __init__(self, level=5, res_scale=1, channel=64, reduction=2, ksize=3, stride=1, softmax_scale=10, average=True, conv=default_conv):
        super(PyramidAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.res_scale = res_scale
        self.softmax_scale = softmax_scale
        # Pyramid scale factors: 1.0, 0.9, 0.8, ... (one per level).
        self.scale = [1-i/10 for i in range(level)]
        self.average = average
        # Lower bound for the patch norms so the normalisation below never
        # divides by zero; registered as a buffer so it follows the module's
        # device.
        escape_NaN = torch.FloatTensor([1e-4])
        self.register_buffer('escape_NaN', escape_NaN)
        # 1x1 convolutions: query embedding of the input, key embedding of the
        # pyramid levels, and the full-width embedding used for reconstruction.
        self.conv_match_L_base = BasicBlock(conv,channel,channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match = BasicBlock(conv,channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = BasicBlock(conv,channel, channel,1,bn=False, act=nn.PReLU())

    def forward(self, input):
        """
        Apply pyramid attention to a feature map.

        :param input: 4-D tensor; dimensions 2 and 3 are used as the spatial
            size of the output (presumably [N, channel, H, W] — confirm with
            callers)
        :return: attention output plus input * res_scale
        """
        res = input
        #theta
        match_base = self.conv_match_L_base(input)
        shape_base = list(res.size())
        # One entry per sample: matching is done sample by sample below.
        input_groups = torch.split(match_base,1,dim=0)
        # patch size for matching
        kernel = self.ksize
        # raw_w is for reconstruction
        raw_w = []
        # w is for matching
        w = []
        #build feature pyramid
        for i in range(len(self.scale)):
            ref = input
            if self.scale[i]!=1:
                ref = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic',
                                    align_corners=True,recompute_scale_factor=True)
            #feature transformation function f
            base = self.conv_assembly(ref)
            shape_input = base.shape
            #sampling
            raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
                                            strides=[self.stride,self.stride],
                                            rates=[1, 1],
                                            padding='same') # [N, C*k*k, L]
            raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
            raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3) # raw_shape: [N, L, C, k, k]
            raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
            raw_w.append(raw_w_i_groups)
            #feature transformation function g
            ref_i = self.conv_match(ref)
            shape_ref = ref_i.shape
            #sampling
            w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
                                        strides=[self.stride, self.stride],
                                        rates=[1, 1],
                                        padding='same')
            w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
            w_i = w_i.permute(0, 4, 1, 2, 3) # w shape: [N, L, C, k, k]
            w_i_groups = torch.split(w_i, 1, dim=0)
            w.append(w_i_groups)
        y = []
        for idx, xi in enumerate(input_groups):
            #group in a filter
            # Patches of sample idx from every pyramid level, stacked along
            # dim 0 so that each patch acts as one convolution filter.
            wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))],dim=0) # [L, C, k, k]
            #normalize
            # L2 norm of each patch, clamped from below by escape_NaN.
            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                                     axis=[1, 2, 3],
                                                     keepdim=True)),
                               self.escape_NaN)
            wi_normed = wi/ max_wi
            #matching
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1]) # xi: 1*c*H*W
            # NOTE(review): the comment below predates the pyramid; here L is
            # the total patch count over all levels (wi.shape[0]).
            yi = F.conv2d(xi, wi_normed, stride=1) # [1, L, H, W] L = shape_ref[2]*shape_ref[3]
            yi = yi.view(1,wi.shape[0], shape_base[2], shape_base[3]) # (B=1, C=32*32, H=32, W=32)
            # softmax matching score
            yi = F.softmax(yi*self.softmax_scale, dim=1)
            if self.average == False:
                # Hard attention: keep only the best-matching patch per location.
                yi = (yi == yi.max(dim=1,keepdim=True)[0]).float()
            # deconv for patch pasting
            raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))],dim=0)
            # NOTE(review): padding=1 matches the default ksize=3 and the /4.
            # normaliser is a fixed constant — confirm both before using a
            # different ksize or stride.
            yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride,padding=1)/4.
            y.append(yi)
        y = torch.cat(y, dim=0)+res*self.res_scale # back to the mini-batch
        return y
class PreprocessDataset(Dataset):
    """
    Dataset producing (high-resolution, low-resolution) training pairs.

    Every file found under ``path`` (recursively) is treated as an image.
    Each item is a random ``size`` x ``size`` crop, randomly mirrored,
    scaled to [-1, 1]; the low-resolution counterpart is the crop
    down-sampled by 2 with max pooling.

    :param path: root directory that is walked recursively
    :param size: side length of the square high-resolution crop
    """

    # How many random replacement images are tried when the requested one is
    # unreadable or smaller than the crop size.
    _MAX_RETRIES = 10

    def __init__(self,path,size = 96):
        super().__init__()
        self.size = size
        # Collect files from every directory visited by os.walk, not just the
        # last one.
        self.allImgs = [
            os.path.join(root, name)
            for root, _dirs, files in os.walk(path)
            for name in files
        ]
        ra.shuffle(self.allImgs)

    def __len__(self):
        return len(self.allImgs)

    def _read_rgb(self, img_path):
        """Return the image as an RGB array, or None if it is unreadable or
        smaller than the crop size in either dimension."""
        img = cv2.imread(img_path)
        if img is None:  # cv2.imread signals failure by returning None
            return None
        height, width = img.shape[:2]
        if height < self.size or width < self.size:
            return None
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self,index):
        """
        :param index: position in the shuffled file list
        :return: (hr, lr) float tensors in [-1, 1]; hr is [3, size, size] and
            lr is hr max-pooled with a 2x2 window
        :raises RuntimeError: if neither the requested image nor any of the
            randomly drawn replacements is usable
        """
        img = self._read_rgb(self.allImgs[index])
        attempts = 0
        while img is None:
            if attempts >= self._MAX_RETRIES:
                raise RuntimeError(
                    'Could not find a readable image of at least %dx%d pixels '
                    'after %d retries.' % (self.size, self.size, self._MAX_RETRIES))
            attempts += 1
            img = self._read_rgb(self.allImgs[ra.randint(0,len(self)-1)])

        height,width,_ = img.shape
        # randint is inclusive on both ends, so the largest valid start is
        # exactly (extent - size); this also accepts images of exactly `size`.
        xStart = ra.randint(0,width-self.size)
        yStart = ra.randint(0,height-self.size)
        img = img[yStart:self.size + yStart,xStart:self.size + xStart,:]
        if ra.random() > 0.5:
            img = cv2.flip(img,1)
        hr = torch.tensor(np.transpose(img,(2,0,1)))/255.0
        hr = (hr - 0.5)/0.5
        lr = F.max_pool2d(hr,2)
        return hr,lr
class ResBlock(nn.Module):
    """
    Residual block: 1x1 conv -> batch norm -> ReLU -> 3x3 reflect-padded conv
    -> batch norm, added to the block input and passed through a final ReLU.

    :param inChannals: number of channels, identical for input and output
    """

    def __init__(self,inChannals):
        super().__init__()
        layers = [
            nn.Conv2d(inChannals, inChannals, kernel_size=1, bias=False),
            nn.BatchNorm2d(inChannals),
            nn.ReLU(inplace=True),
            nn.Conv2d(inChannals, inChannals, kernel_size=3, stride=1,
                      padding=1, bias=False, padding_mode='reflect'),
            nn.BatchNorm2d(inChannals),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self,input):
        """Return relu(input + model(input))."""
        residual = self.model(input)
        summed = input + residual
        return F.relu(summed, inplace=True)
class Sequential(nn.Sequential):
    """
    Chain of ``blockNum`` ResBlocks with one PyramidAttention (4 levels)
    placed at position int(blockNum / 2).

    :param inChannals: channel count used by every block
    :param blockNum: number of ResBlocks
    """

    def __init__(self,inChannals,blockNum = 8):
        # Build the residual blocks before the attention block so the modules
        # are created in the same order as before.
        res_blocks = [ResBlock(inChannals) for _ in range(blockNum)]
        attention = PyramidAttention(channel=inChannals, level=4)
        middle = int(blockNum/2)
        ordered = res_blocks[:middle] + [attention] + res_blocks[middle:]
        super().__init__(*ordered)
class Model(nn.Module):
    """
    2x super-resolution network: a convolutional stem, a residual/attention
    trunk with a skip connection around it, and a PixelShuffle up-sampling
    head ending in Tanh.

    :param channals: number of feature channels used throughout
    :param blockNum: number of residual blocks in the trunk
    """

    def __init__(self,channals = 64,blockNum = 6):
        super().__init__()
        stem = [
            nn.Conv2d(3, channals, kernel_size=7, padding=3, stride=1,
                      padding_mode='reflect', bias=False),
            nn.BatchNorm2d(channals),
            nn.ReLU(inplace=True),
            nn.Conv2d(channals, channals, kernel_size=3, padding=1, stride=1,
                      padding_mode='reflect', bias=False),
            nn.BatchNorm2d(channals),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*stem)

        self.sequential = Sequential(channals, blockNum)

        head = [
            # 4x the channels so PixelShuffle(2) returns to `channals`
            # channels at twice the spatial resolution.
            nn.Conv2d(channals, channals * 4, kernel_size=3, padding=1, stride=1,
                      padding_mode='reflect'),
            nn.PixelShuffle(2),
            nn.Conv2d(channals, channals, kernel_size=3, padding=1, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channals, 3, kernel_size=1, stride=1),
            nn.Tanh(),
        ]
        self.upSample = nn.Sequential(*head)

    def forward(self,input):
        """Map a low-resolution batch to its up-sampled reconstruction."""
        shallow = self.features(input)
        deep = self.sequential(shallow)
        fused = shallow + deep
        return self.upSample(fused)
def gaussian(window_size, sigma):
    """
    Return a normalised 1-D Gaussian kernel.

    :param window_size: number of samples in the kernel
    :param sigma: standard deviation of the Gaussian
    :return: 1-D tensor of length window_size, centred on window_size // 2
        and summing to 1
    """
    center = window_size // 2
    # math.exp is used explicitly: a bare `exp` is not imported by this file.
    gauss = torch.Tensor([math.exp(-(x - center)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()
def create_window(window_size, channel):
    """
    Build the 2-D Gaussian window used by SSIM, one copy per channel.

    :param window_size: side length of the square window
    :param channel: number of channels (one filter per channel)
    :return: contiguous float tensor of shape
        [channel, 1, window_size, window_size]
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    # Outer product of the 1-D kernel with itself gives the 2-D kernel.
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # A plain tensor is returned: `Variable` is not imported by this file,
    # and tensors are accepted directly by F.conv2d.
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window
def _ssim(img1, img2, window, window_size, channel, size_average = True):
mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
def ssim(img1, img2, window_size = 11, size_average = True):
    """
    Structural similarity between two 4-D image batches.

    :param img1: first image batch; its second dimension gives the channel
        count
    :param img2: second image batch
    :param window_size: side length of the Gaussian window
    :param size_average: True for a single mean value, False for one value
        per sample
    :return: SSIM tensor as produced by _ssim
    """
    _batch, channel, _height, _width = img1.size()
    kernel = create_window(window_size, channel)

    # Put the window on the same device and dtype as the images.
    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())
    kernel = kernel.type_as(img1)

    return _ssim(img1, img2, kernel, window_size, channel, size_average)
def PSRN(img1, img2, max_val=255.0):
    """
    Peak signal-to-noise ratio in dB between two tensors.

    :param img1: first tensor
    :param img2: second tensor, broadcastable with img1
    :param max_val: peak signal value of the data; defaults to 255.0 (the
        previously hard-coded value). Pass the real range of the inputs,
        e.g. 2.0 for images normalised to [-1, 1].
    :return: PSNR as a Python float; 100.0 when the mean squared error is
        below 1e-10 (effectively identical inputs)
    """
    # .item() yields a plain float, so the result never carries autograd state.
    mse = torch.mean((img1 - img2) ** 2).item()
    if mse < 1.0e-10:
        return 100.0
    return 10 * math.log10(max_val ** 2 / mse)
def update_lr(optimizer, multiplier = .1):
    """
    Scale the learning rate of every parameter group in place.

    :param optimizer: a torch optimizer
    :param multiplier: factor applied to each group's current learning rate
    """
    # Editing param_groups directly avoids copying and reloading the whole
    # optimizer state just to change one hyper-parameter.
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * multiplier
if __name__ == '__main__':
    # Training entry point: trains the network on random crops, logs running
    # loss/SSIM/PSNR, and after 100 steps of each epoch saves a preview image
    # and a checkpoint, then moves on to the next epoch.
    path = '../COCO/COCO_COCO_2017_Unlabeled_Images/unlabeled2017/'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = PreprocessDataset(path,size = 96)
    trainData = DataLoader(dataset,batch_size = 32,num_workers = 4,shuffle = True)
    net = Model(channals = 64,blockNum = 24).to(device)
    print(net)
    criteria = nn.L1Loss()
    optimizer = optim.AdamW(net.parameters(),lr = 1e-4)
    totalStep = len(trainData)
    modelPath = './checkpoint/PANet.pth'
    if os.path.exists(modelPath):
        # map_location lets a checkpoint written on a GPU be resumed on CPU.
        checkpoint = torch.load(modelPath, map_location = device)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The checkpoint is written once an epoch has finished its steps, so
        # resume with the following epoch. Restarting the saved epoch would
        # also re-apply the learning-rate decay at epochs 20 and 40 to an
        # optimizer state that already contains it.
        startEpoch = checkpoint['epoch'] + 1
    else:
        startEpoch = 0
    os.makedirs('./checkpoint', exist_ok = True)
    # Must be the same directory (same letter case) the previews are saved to.
    os.makedirs('./img', exist_ok = True)
    for epoch in range(startEpoch,10000):
        if epoch == 20 or epoch == 40:
            update_lr(optimizer, multiplier = .1)
        totalSSIM = 0.0
        totalPSRN = 0.0
        totalLoss = 0.0
        net.train(True)
        for step,(hr,lr) in enumerate(trainData,1):
            hr,lr = hr.to(device),lr.to(device)
            net.zero_grad()
            output = net(lr)
            loss = criteria(output,hr)
            loss.backward()
            optimizer.step()
            # Accumulate plain floats; summing the tensors themselves would
            # keep every step's autograd graph alive.
            totalLoss += loss.item()
            with torch.no_grad():
                totalSSIM += float(ssim(output,hr))
                totalPSRN += float(PSRN(output,hr))
            print("[Epoch %d] Step: %d/%d Loss: %.4f|ssim: %.4f|psrn: %.4f" %
                  (epoch,step,totalStep,totalLoss/step,totalSSIM/step,totalPSRN/step))
            if step >= 100:
                net.train(False)
                # Preview only: no gradients needed.
                with torch.no_grad():
                    outputs = net(lr)
                outputs = torch.cat([hr,outputs],dim = 0)
                save_image(outputs,'./img/Result_epoch_%08d.jpg' % epoch,nrow = 8,normalize = True)
                state_dict = {'net': net.state_dict(),'optimizer':optimizer.state_dict(),'epoch':epoch}
                torch.save(state_dict,'./checkpoint/PANet.pth')
                break