A PyTorch implementation of MATLAB's imresize function with the bicubic kernel.
PyTorch 版本
import torch
import torch.nn as nn
import numpy as np
class bicubic_imresize(nn.Module):
    """
    A PyTorch implementation of MATLAB's ``imresize`` with the bicubic kernel.

    The module holds no parameters.  ``forward`` resamples the last two
    dimensions of a 4-D tensor separably: dimension 2 first, then dimension 3.

    NOTE(review): taps that fall outside the image are clamped to the nearest
    edge pixel.  MATLAB's imresize is understood to mirror the image at the
    border instead, so values close to the border may differ slightly from
    MATLAB -- confirm before relying on exact border parity.
    """

    def __init__(self):
        super(bicubic_imresize, self).__init__()

    def cubic(self, x):
        """Evaluate the piecewise cubic interpolation kernel element-wise.

        For ``|x| <= 1`` the value is ``1.5|x|^3 - 2.5|x|^2 + 1``; for
        ``1 < |x| <= 2`` it is ``-0.5|x|^3 + 2.5|x|^2 - 4|x| + 2``; elsewhere
        it is 0.
        """
        absx = torch.abs(x)
        absx2 = absx * absx
        absx3 = absx2 * absx
        # The masks are cast to float so they can gate the two polynomials.
        inner = (absx <= 1).to(torch.float32)
        outer = ((1 < absx) & (absx <= 2)).to(torch.float32)
        return ((1.5 * absx3 - 2.5 * absx2 + 1) * inner
                + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * outer)

    def _contribute_1d(self, in_len, out_len, scale, device):
        """Build the interpolation weights and source indices for one axis.

        Returns ``(weight, indice)``, both shaped ``(1, out_len, K)``.
        ``indice`` holds 1-based source positions stored as floats.
        """
        # When shrinking, the kernel is stretched by 1/scale so that it also
        # low-pass filters the signal.
        kernel_width = 4 / scale if scale < 1 else 4

        # 1-based output coordinates and their centres in input space.
        x = torch.arange(start=1, end=out_len + 1, device=device).to(torch.float32)
        u = x / scale + 0.5 * (1 - 1 / scale)
        left = torch.floor(u - kernel_width / 2)

        # Two extra taps guarantee the kernel support is always covered.
        P = int(np.ceil(kernel_width)) + 2
        taps = torch.arange(start=0, end=P, device=device).to(torch.float32)

        indice = left.unsqueeze(1) + taps.unsqueeze(0)   # (out_len, P)
        mid = u.unsqueeze(1) - indice.unsqueeze(0)       # (1, out_len, P)

        if scale < 1:
            weight = scale * self.cubic(mid * scale)
        else:
            weight = self.cubic(mid)
        # Normalise so that every output sample's weights sum to one.
        weight = weight / torch.sum(weight, 2).unsqueeze(2)

        # Clamp taps that fall outside the image to the nearest valid pixel.
        indice = torch.clamp(indice, 1, in_len).unsqueeze(0)   # (1, out_len, P)

        # Keep every tap column that is non-zero for at least one output
        # sample.  Deciding this from the first row alone drops live taps
        # whenever the zero pattern differs between rows (non-integer scales).
        keep = (weight != 0).any(dim=1)[0]
        return weight[:, :, keep], indice[:, :, keep]

    def contribute(self, in_size, out_size, scale, cuda_flag):
        """Compute weights and indices for both spatial axes.

        Args:
            in_size: ``[h, w]`` of the input.
            out_size: ``[h, w]`` of the output.
            scale: resize factor shared by both axes.
            cuda_flag: when truthy the tensors are created on the current
                CUDA device, otherwise on the CPU.

        Returns:
            ``(weight0, weight1, indice0, indice1)``, each shaped
            ``(1, out_len, K)`` for its axis.
        """
        device = torch.device('cuda' if cuda_flag else 'cpu')
        weight0, indice0 = self._contribute_1d(in_size[0], out_size[0], scale, device)
        weight1, indice1 = self._contribute_1d(in_size[1], out_size[1], scale, device)
        return weight0, weight1, indice0, indice1

    def forward(self, input, scale=1 / 4):
        """Resize a ``(b, c, h, w)`` tensor by ``scale``.

        The output spatial size is ``(int(h * scale), int(w * scale))``.

        Raises:
            ValueError: if either output dimension would be smaller than 1.
        """
        b, c, h, w = input.shape
        out_h, out_w = int(h * scale), int(w * scale)
        if out_h < 1 or out_w < 1:
            raise ValueError(
                'scale {} gives an empty output for input size {}x{}'.format(scale, h, w))

        weight0, weight1, indice0, indice1 = self.contribute(
            [h, w], [out_h, out_w], scale, input.is_cuda)

        # Make sure the tables live on the very same device as the input
        # (matters when the input is not on the current CUDA device).
        weight0 = weight0.squeeze(0).to(input.device)
        indice0 = indice0.squeeze(0).long().to(input.device)
        weight1 = weight1.squeeze(0).to(input.device)
        indice1 = indice1.squeeze(0).long().to(input.device)

        # Axis 2: gather (b, c, out_h, K, w), weight the taps, sum them out.
        # Indices are 1-based, hence the "- 1".
        out = input[:, :, indice0 - 1, :] * weight0.unsqueeze(0).unsqueeze(1).unsqueeze(4)
        out = torch.sum(out, dim=3)

        # Axis 3: swap the spatial axes so the same gather pattern applies.
        A = out.permute(0, 1, 3, 2)
        out = A[:, :, indice1 - 1, :] * weight1.unsqueeze(0).unsqueeze(1).unsqueeze(4)
        out = torch.sum(out, dim=3).permute(0, 1, 3, 2)
        return out
NumPy 版本
来源: https://github.com/fatheral/matlab_imresize
from __future__ import print_function
import numpy as np
from math import ceil, floor
from skimage import img_as_float
def deriveSizeFromScale(img_shape, scale):
    """Return the output length of the first two axes for the given scales.

    Each length is ``ceil(scale[k] * img_shape[k])`` converted to ``int``.
    """
    return [int(ceil(scale[axis] * img_shape[axis])) for axis in range(2)]
def deriveScaleFromSize(img_shape_in, img_shape_out):
    """Return the per-axis scale factors (output / input) of the first two axes."""
    return [1.0 * img_shape_out[axis] / img_shape_in[axis] for axis in range(2)]
def triangle(x):
    """Evaluate the triangular (linear interpolation) kernel element-wise.

    The value is ``x + 1`` on ``[-1, 0)``, ``1 - x`` on ``[0, 1]`` and 0
    elsewhere.  The input is converted to a float64 array first.
    """
    values = np.array(x).astype(np.float64)
    rising = (values >= -1) & (values < 0)
    falling = (values >= 0) & (values <= 1)
    return (values + 1) * rising + (1 - values) * falling
def cubic(x):
    """Evaluate the piecewise cubic interpolation kernel element-wise.

    For ``|x| <= 1`` the value is ``1.5|x|^3 - 2.5|x|^2 + 1``; for
    ``1 < |x| <= 2`` it is ``-0.5|x|^3 + 2.5|x|^2 - 4|x| + 2``; elsewhere 0.
    The input is converted to a float64 array first.
    """
    magnitude = np.absolute(np.array(x).astype(np.float64))
    squared = magnitude * magnitude
    cubed = squared * magnitude
    near = (1.5 * cubed - 2.5 * squared + 1) * (magnitude <= 1)
    far = (-0.5 * cubed + 2.5 * squared - 4 * magnitude + 2) * ((1 < magnitude) & (magnitude <= 2))
    return near + far
def contributions(in_length, out_length, scale, kernel, k_width):
    """Compute interpolation weights and source indices for one axis.

    Args:
        in_length: number of samples along the axis in the input.
        out_length: number of samples along the axis in the output.
        scale: resize factor for this axis.
        kernel: callable evaluating the interpolation kernel on an array.
        k_width: support width of ``kernel`` at scale 1.

    Returns:
        ``(weights, indices)``.  Because the kept columns are selected with
        the tuple returned by ``np.nonzero``, both arrays come back shaped
        ``(out_length, 1, K)``; ``indices`` holds 0-based source positions.
    """
    if scale < 1:
        # Shrinking: stretch the kernel by 1/scale so it also low-pass filters.
        def h(x):
            return scale * kernel(scale * x)
        kernel_width = 1.0 * k_width / scale
    else:
        h = kernel
        kernel_width = k_width

    # 1-based output coordinates mapped to their centres in input space.
    out_coords = np.arange(1, out_length + 1).astype(np.float64)
    u = out_coords / scale + 0.5 * (1 - 1 / scale)
    left = np.floor(u - kernel_width / 2)

    # Two extra taps guarantee the kernel support is always covered.
    P = int(ceil(kernel_width)) + 2
    # The trailing "- 1" converts from 1-based to 0-based positions.
    indices = (left[:, np.newaxis] + np.arange(P) - 1).astype(np.int32)
    weights = h(u[:, np.newaxis] - indices - 1)
    weights = weights / np.sum(weights, axis=1)[:, np.newaxis]

    # Mirror out-of-range positions back into the signal.
    ramp = np.arange(in_length)
    aux = np.concatenate((ramp, ramp[::-1])).astype(np.int32)
    indices = aux[np.mod(indices, aux.size)]

    # Drop tap columns whose weight is zero for every output sample.
    ind2store = np.nonzero(np.any(weights, axis=0))
    return weights[:, ind2store], indices[:, ind2store]
def imresizemex(inimg, weights, indices, dim):
    """Resize ``inimg`` along axis ``dim`` (0 or 1) with explicit Python loops.

    ``weights`` and ``indices`` are indexed by output position along their
    first axis.  The ``squeeze(axis=0)`` below implies each per-position block
    carries a leading singleton axis, i.e. the arrays are presumably shaped
    ``(out, 1, K)`` -- verify against the caller.

    Returns a float array, except for uint8 input where the result is clipped
    to [0, 255], rounded and returned as uint8.
    """
    src_shape = inimg.shape
    n_out = weights.shape[0]
    dst_shape = list(src_shape)
    dst_shape[dim] = n_out
    outimg = np.zeros(dst_shape)

    def blend(samples, taps):
        # Weighted sum over the tap axis after removing the singleton axis.
        return np.sum(np.multiply(np.squeeze(samples, axis=0), taps.T), axis=0)

    if dim == 0:
        for col in range(src_shape[1]):
            for pos in range(n_out):
                picked = inimg[indices[pos, :], col].astype(np.float64)
                outimg[pos, col] = blend(picked, weights[pos, :])
    elif dim == 1:
        for row in range(src_shape[0]):
            for pos in range(n_out):
                picked = inimg[row, indices[pos, :]].astype(np.float64)
                outimg[row, pos] = blend(picked, weights[pos, :])

    if inimg.dtype != np.uint8:
        return outimg
    return np.around(np.clip(outimg, 0, 255)).astype(np.uint8)
def imresizevec(inimg, weights, indices, dim):
    """Resize ``inimg`` along axis ``dim`` (0 or 1) using vectorised indexing.

    ``weights`` is read as ``(out, <singleton>, K)``: only its first and third
    dimensions are used, and the gathered samples have the matching singleton
    axis squeezed away.  The broadcasting used here presumably expects a 3-D
    image (rows, columns, channels) -- verify against the caller.

    Returns a float array, except for uint8 input where the result is clipped
    to [0, 255], rounded and returned as uint8.
    """
    shape = weights.shape
    n_out, n_taps = shape[0], shape[2]
    if dim == 0:
        taps = weights.reshape((n_out, n_taps, 1, 1))
        gathered = inimg[indices].squeeze(axis=1).astype(np.float64)
        outimg = np.sum(taps * gathered, axis=1)
    elif dim == 1:
        taps = weights.reshape((1, n_out, n_taps, 1))
        gathered = inimg[:, indices].squeeze(axis=2).astype(np.float64)
        outimg = np.sum(taps * gathered, axis=2)

    if inimg.dtype != np.uint8:
        return outimg
    return np.around(np.clip(outimg, 0, 255)).astype(np.uint8)
def resizeAlongDim(A, dim, weights, indices, mode="vec"):
    """Resize ``A`` along ``dim`` with the implementation selected by ``mode``.

    ``mode == "org"`` uses the loop-based ``imresizemex``; any other value
    uses the vectorised ``imresizevec``.
    """
    resize_fn = imresizemex if mode == "org" else imresizevec
    return resize_fn(A, weights, indices, dim)
def imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"):
    """Resize an image along its first two axes.

    Args:
        I: input array; 2-D input is given a temporary trailing axis.
        scalar_scale: one scale factor applied to both axes.
        method: ``'bicubic'`` or ``'bilinear'``.
        output_shape: target shape; its first two entries set the output size.
        mode: ``"org"`` for the loop-based resampler, anything else for the
            vectorised one.

    Exactly one of ``scalar_scale`` and ``output_shape`` must be given.

    Returns:
        The resized image after ``convertDouble2Byte``.

    Raises:
        ValueError: for an unknown ``method``, or when both or neither of
            ``scalar_scale`` / ``output_shape`` are supplied.
    """
    # uint8 data, or data whose maximum exceeds 2, is passed through
    # img_as_float before resampling.
    if I.dtype == 'uint8' or np.max(I) > 2:
        I = img_as_float(I)

    if method == 'bicubic':
        kernel = cubic
    elif method == 'bilinear':
        kernel = triangle
    else:
        raise ValueError('unidentified kernel method supplied')
    kernel_width = 4.0

    # Resolve the per-axis scale and the output size from whichever argument
    # was supplied; giving both or neither is an error.
    has_scale = scalar_scale is not None
    has_shape = output_shape is not None
    if has_scale == has_shape:
        raise ValueError('either scalar_scale OR output_shape should be defined')
    if has_scale:
        factor = float(scalar_scale)
        scale = [factor, factor]
        output_size = deriveSizeFromScale(I.shape, scale)
    else:
        scale = deriveScaleFromSize(I.shape, output_shape)
        output_size = list(output_shape)

    # Axes are processed in ascending order of their scale factor.
    order = np.argsort(np.array(scale))

    per_axis = [
        contributions(I.shape[axis], output_size[axis], scale[axis], kernel, kernel_width)
        for axis in range(2)
    ]
    weights = [w for w, _ in per_axis]
    indices = [ind for _, ind in per_axis]

    B = np.copy(I)
    flag2D = B.ndim == 2
    if flag2D:
        B = np.expand_dims(B, axis=2)
    for dim in order:
        B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode)
    if flag2D:
        B = np.squeeze(B, axis=2)
    return convertDouble2Byte(B)
def convertDouble2Byte(I):
    """Map values from [0, 1] to uint8: clip to [0, 1], scale by 255, round."""
    clipped = np.clip(I, 0.0, 1.0)
    return np.around(255 * clipped).astype(np.uint8)
if __name__ == '__main__':
    # Demo: downscale by 2, upscale back, and report the PSNR against the
    # original image (reference MATLAB values are given in the comments).
    from skimage import io
    try:
        # Newer scikit-image releases expose PSNR here; the old
        # skimage.measure.compare_psnr alias has been removed.
        from skimage.metrics import peak_signal_noise_ratio as psnr
    except ImportError:
        # Fallback for older scikit-image installations.
        from skimage.measure import compare_psnr as psnr

    # Grayscale example.
    im = io.imread('E:\\DATA\\retina\\clean_001_clahe.png')
    [h, w] = im.shape
    im_r = imresize(im, output_shape=[h // 2, w // 2])
    im_r2 = imresize(im_r, output_shape=[h, w])
    print(psnr(im_r2, im))  # 26.2788 (MATLAB: 26.2772)

    # Colour example.
    im = io.imread('E:\\DATA\\Set5\\original\\butterfly.png')
    [h, w, c] = im.shape
    im_r = imresize(im, output_shape=[h // 2, w // 2, c])
    im_r2 = imresize(im_r, output_shape=[h, w, c])
    print(psnr(im_r2, im))  # 26.1284 (MATLAB: 26.1249)