原始图片:
import random
import math
import torch
import numpy as np
import torch.nn.functional as F
import cv2
import matplotlib.pyplot as plt
from data_gen import draw_box_points
path = './test/timg.jpeg'
im_data = cv2.imread(path)
img = im_data.copy()
# plt.imshow(im_data)
# plt.show()
# 参数设置
debug = True
norm_height = 44
gt = np.asarray([[205,150],[202,126],[365,93],[372,111]])
im_data = torch.from_numpy(im_data).unsqueeze(0)
im_data = im_data.permute(0,3,1,2).to(torch.float)
center = (gt[0, :] + gt[1, :] + gt[2, :] + gt[3, :]) / 4
dw = gt[2, :] - gt[1, :]
dh = gt[1, :] - gt[0, :]
w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1]) + random.randint(-2, 2)
angle_gt = ( math.atan2((gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0]) + math.atan2((gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0]) ) / 2
input_W = im_data.size(3)
input_H = im_data.size(2)
target_h = norm_height
scale = target_h / h
target_gw = int(w * scale) + random.randint(0, int(target_h))
target_gw = max(8, int(round(target_gw / 4)) * 4)
xc = center[0]
yc = center[1]
w2 = w
h2 = h
scalex = (w2 + random.randint(0, int(h2))) / input_W
scaley = h2 / input_H
th11 = scalex * math.cos(angle_gt)
th12 = -math.sin(angle_gt) * scaley
th13 = (2 * xc - input_W - 1) / (input_W - 1) #* torch.cos(angle_var) - (2 * yc - input_H - 1) / (input_H - 1) * torch.sin(angle_var)
th21 = math.sin(angle_gt) * scalex
th22 = scaley * math.cos(angle_gt)
th23 = (2 * yc - input_H - 1) / (input_H - 1) #* torch.cos(angle_var) + (2 * xc - input_W - 1) / (input_W - 1) * torch.sin(angle_var)
t = np.asarray([th11, th12, th13, th21, th22, th23], dtype=np.float)
t = torch.from_numpy(t).type(torch.FloatTensor)
theta = t.view(-1, 2, 3)
grid = F.affine_grid(theta, torch.Size((1, 3, int(target_h ), int(target_gw))))
x = F.grid_sample(im_data, grid)
if debug:
x_c = x.data.cpu().numpy()[0]
x_data_draw = x_c.swapaxes(0, 2)
x_data_draw = x_data_draw.swapaxes(0, 1)
x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
x_data_draw = x_data_draw[:, :, ::-1]
cv2.circle(img, (int(center[0]), int(center[1])), 5, (0, 255, 0))
cv2.imshow('im_data', x_data_draw)
# draw_box_points(img, pts)
draw_box_points(img, gt, color=(0, 0, 255))
cv2.imshow('img', img)
cv2.waitKey(100)
裁剪出来的图片:
效果还是有的,但是采用了pytorch的affine_grid和grid_sample,并不知道theta矩阵的计算方式。
将文字区域调整到同样的高度,不同的长度,但是字会出现左右(最左,最右的字)会超出文字区域。
第二种方案
采用rroi_align的方式进行旋转矫正和crop操作,并使用cuda进行运算加速
实现细节
结果
可以看到效果比上面的结果好多了。
下一步工作
实现批量的rroi_align操作