最近项目用到了tps算法,opencv2封装的tps实现起来比较慢,于是用pytorch实现了一下,可以支持gpu加速,就很nice了,在这里记录一下!
1. 简介
薄板样条函数(TPS)是一种很常见的插值方法。因为它一般都是基于2D插值,所以经常用在在图像配准中。在两张图像中找出N个匹配点,应用TPS可以将这N个点形变到对应位置,同时给出了整个空间的形变(插值)。
2. 实现
1. opencv的tps使用
import cv2
import numpy as np
import random
import torch
from torchvision.transforms import ToTensor, ToPILImage
DEVICE = torch.device("cpu")
def choice3(img):
'''
产生波浪型文字
:param img:
:return:
'''
h, w = img.shape[0:2]
N = 5
pad_pix = 50
points = []
dx = int(w/ (N - 1))
for i in range( N):
points.append((dx * i, pad_pix))
points.append((dx * i, pad_pix + h))
#加边框
img = cv2.copyMakeBorder(img, pad_pix, pad_pix, 0, 0, cv2.BORDER_CONSTANT,
value=(int(img[0][0][0]), int(img[0][0][1]), int(img[0][0][2])))
#原点
source = np.array(points, np.int32)
source = source.reshape(1, -1, 2)
#随机扰动幅度
rand_num_pos = random.uniform(20, 30)
rand_num_neg = -1 * rand_num_pos
newpoints = []
for i in range(N):
rand = np.random.choice([rand_num_neg, rand_num_pos], p=[0.5, 0.5])
if(i == 1):
nx_up = points[2 * i][0]
ny_up = points[2 * i][1] + rand
nx_down = points[2 * i + 1][0]
ny_down = points[2 * i + 1][1] + rand
elif (i == 4):
rand = rand_num_neg if rand > 1 else rand_num_pos
nx_up = points[2 * i][0]
ny_up = points[2 * i][1] + rand
nx_down = points[2 * i + 1][0]
ny_down = points[2 * i + 1][1] + rand
else:
nx_up = points[2 * i][0]
ny_up = points[2 * i][1]
nx_down = points[2 * i + 1][0]
ny_down = points[2 * i + 1][1]
newpoints.append((nx_up, ny_up))
newpoints.append((nx_down, ny_down))
#target点
target = np.array(newpoints, np.int32)
target = target.reshape(1, -1, 2)
#计算matches
matches = []
for i in range(1, 2*N + 1):
matches.append(cv2.DMatch(i, i, 0))
return source, target, matches, img
def norm(points_int, width, height):
"""
将像素点坐标归一化至 -1 ~ 1
"""
points_int_clone = torch.from_numpy(points_int).detach().float().to(DEVICE)
x = ((points_int_clone * 2)[..., 0] / (width - 1) - 1)
y = ((points_int_clone * 2)[..., 1] / (height - 1) - 1)
return torch.stack([x, y], dim=-1).contiguous().view(-1, 2)
class TPS(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, X, Y, w, h, device):
""" 计算grid"""
grid = torch.ones(1, h, w, 2, device=device)
grid[:, :, :, 0] = torch.linspace(-1, 1, w)
grid[:, :, :, 1] = torch.linspace(-1, 1, h)[..., None]
grid = grid.view(-1, h * w, 2)
""" 计算W, A"""
n, k = X.shape[:2]
device = X.device
Z = torch.zeros(1, k + 3, 2, device=device)
P = torch.ones(n, k, 3, device=device)
L = torch.zeros(n, k + 3, k + 3, device=device)
eps = 1e-9
D2 = torch.pow(X[:, :, None, :] - X[:, None, :, :], 2).sum(-1)
K = D2 * torch.log(D2 + eps)
P[:, :, 1:] = X
Z[:, :k, :] = Y
L[:, :k, :k] = K
L[:, :k, k:] = P
L[:, k:, :k] = P.permute(0, 2, 1)
Q = torch.solve(Z, L)[0]
W, A = Q[:, :k], Q[:, k:]
""" 计算U """
eps = 1e-9
D2 = torch.pow(grid[:, :, None, :] - X[:, None, :, :], 2).sum(-1)
U = D2 * torch.log(D2 + eps)
""" 计算P """
n, k = grid.shape[:2]
device = grid.device
P = torch.ones(n, k, 3, device=device)
P[:, :, 1:] = grid
# grid = P @ A + U @ W
grid = torch.matmul(P, A) + torch.matmul(U, W)
return grid.view(-1, h, w, 2)
if __name__=='__main__':
# 弯曲水平文本
img = cv2.imread('data/test.jpg', cv2.IMREAD_COLOR)
source, target, matches, img = choice3(img)
# #opencv版tps
# tps = cv2.createThinPlateSplineShapeTransformer()
# tps.estimateTransformation(source, target, matches)
# img = tps.warpImage(img)
# cv2.imshow('test.png', img)
# cv2.imwrite('test.png', img)
# cv2.waitKey(0)
#torch实现tps
ten_img = ToTensor()(img).to(DEVICE)
h, w = ten_img.shape[1], ten_img.shape[2]
ten_source = norm(source, w, h)
ten_target = norm(target, w, h)
tps = TPS()
warped_grid = tps(ten_target[None, ...], ten_source[None, ...], w, h, DEVICE) #这个输入的位置需要归一化,所以用norm
ten_wrp = torch.grid_sampler_2d(ten_img[None, ...], warped_grid, 0, 0)
new_img_torch = np.array(ToPILImage()(ten_wrp[0].cpu()))
cv2.imshow('test.png', new_img_torch)
cv2.imwrite('test.png', new_img_torch)
cv2.waitKey(0)
3. 效果
- 贴个效果图对比:
上图可以看出,pytorch实现与cv2的tps的效果完全对齐,所以重点看耗时,接下来贴耗时的对比图(差距还是蛮大的,图片越大差距越大)
如果对你有帮助的话,希望给个赞,谢谢~
参考1:TPS 薄板样条插值 python的opencv实现
注,这个参考可以初步了解使用cv2的tps使用,但是具体细节上还存在错误
参考2:薄板样条函数(Thin plate splines)的讨论与分析
参考3:数值方法——薄板样条插值(Thin-Plate Spline)