0. 所有需要用到的库
import torch
import numpy as np
from PIL import Image
from PIL import ImageEnhance, ImageDraw
from torchvision import transforms
import math
import random
1. 旋转
def random_rotate(img, bboxes, angle_range=(-10, 10)):
    """Rotate an image by a random angle and re-fit its bounding boxes.

    The image is rotated about its centre with
    ``img.rotate(-angle, expand=False)``, so the canvas size is unchanged.
    The four corners of every box are rotated about the same centre with
    ``rotate_point`` and the box is replaced by the axis-aligned rectangle
    that encloses the rotated corners.

    Args:
        img: Image object exposing ``rotate(angle, expand=...)`` and a
            ``size`` attribute unpacked as ``(width, height)``, e.g. a
            ``PIL.Image.Image``.
        bboxes: Iterable of ``[x1, y1, x2, y2]`` boxes (a list of lists or
            an ``(N, 4)`` tensor). May be empty.
        angle_range: ``(low, high)`` bounds in degrees for the random angle.

    Returns:
        ``(rotated_img, new_bboxes)``. ``new_bboxes`` is a float32 tensor of
        shape ``(N, 4)``, or ``(0, 4)`` when no boxes were given. Boxes are
        not clipped to the image bounds.
    """
    angle = random.uniform(angle_range[0], angle_range[1])
    # PIL treats positive angles as counter-clockwise; negating the angle
    # turns the picture clockwise on screen, which matches rotate_point in
    # image coordinates (y axis pointing down).
    img = img.rotate(-angle, expand=False)

    w, h = img.size
    cx = w / 2
    cy = h / 2
    theta = math.radians(angle)
    cos_theta = math.cos(theta)
    sin_theta = math.sin(theta)

    rows = []
    for bbox in bboxes:
        # float() accepts both plain numbers and 0-d tensors.
        x1, y1, x2, y2 = (float(v) for v in bbox)
        corners = [
            rotate_point(px, py, cx, cy, cos_theta, sin_theta)
            for px, py in ((x1, y1), (x2, y2), (x1, y2), (x2, y1))
        ]
        xs = [corner[0] for corner in corners]
        ys = [corner[1] for corner in corners]
        rows.append([min(xs), min(ys), max(xs), max(ys)])

    if not rows:
        # Keep the return type stable instead of failing on an empty input.
        return img, torch.zeros((0, 4), dtype=torch.float32)
    # Build the tensor once rather than concatenating inside the loop.
    return img, torch.tensor(rows, dtype=torch.float32)
def rotate_point(x, y, cx, cy, cos_theta, sin_theta):
    """Rotate the point ``(x, y)`` about the centre ``(cx, cy)``.

    The rotation angle is passed as its precomputed cosine and sine so that
    callers rotating many points only evaluate the trigonometry once.

    Returns:
        The rotated coordinates as an ``(nx, ny)`` tuple.
    """
    # Offset of the point from the rotation centre.
    dx = x - cx
    dy = y - cy
    # Standard 2-D rotation, then translate back to the centre.
    rotated_x = cos_theta * dx - sin_theta * dy + cx
    rotated_y = sin_theta * dx + cos_theta * dy + cy
    return rotated_x, rotated_y
2. 随机裁剪
def random_crop(img, bboxes, p=1):
    """Randomly crop an image while keeping every bounding box fully inside.

    The crop window always contains the smallest rectangle enclosing all
    boxes; each side is pushed outwards by a random amount of the free
    margin between that rectangle and the image edge, and is clamped to the
    image.

    Args:
        img: Image object exposing ``size`` (unpacked as
            ``(width, height)``) and ``crop((left, upper, right, lower))``,
            e.g. a ``PIL.Image.Image``.
        bboxes: Tensor of shape ``(N, 4)`` holding ``[x1, y1, x2, y2]`` rows.
        p: Probability of applying the crop.

    Returns:
        ``(img, bboxes)``. When the crop is applied, the boxes are returned
        as a new tensor shifted into the cropped frame; the input tensor is
        not modified. When the crop is skipped, or there are no boxes, both
        inputs are returned unchanged.
    """
    if random.random() >= p or len(bboxes) == 0:
        return img, bboxes

    # PIL reports size as (width, height).
    w_img, h_img = img.size

    # Smallest rectangle containing every box; the crop must keep it whole.
    x_lo = float(bboxes[:, 0].min())
    y_lo = float(bboxes[:, 1].min())
    x_hi = float(bboxes[:, 2].max())
    y_hi = float(bboxes[:, 3].max())

    # Free margin between that rectangle and each image edge.
    max_l_trans = x_lo
    max_u_trans = y_lo
    max_r_trans = w_img - x_hi
    max_d_trans = h_img - y_hi

    # Round outwards (floor on the near side, ceil on the far side) so the
    # integer crop window never cuts into a box, then clamp to the image.
    crop_xmin = max(0, math.floor(x_lo - random.uniform(0, max_l_trans)))
    crop_ymin = max(0, math.floor(y_lo - random.uniform(0, max_u_trans)))
    crop_xmax = min(w_img, math.ceil(x_hi + random.uniform(0, max_r_trans)))
    crop_ymax = min(h_img, math.ceil(y_hi + random.uniform(0, max_d_trans)))

    img = img.crop((crop_xmin, crop_ymin, crop_xmax, crop_ymax))

    # Shift the boxes into the cropped frame without touching the caller's
    # tensor.
    shifted = bboxes.clone()
    shifted[:, [0, 2]] = shifted[:, [0, 2]] - crop_xmin
    shifted[:, [1, 3]] = shifted[:, [1, 3]] - crop_ymin
    return img, shifted
3. 随机改变颜色亮度对比度
def random_distort(img):
    """Randomly jitter brightness, contrast and colour, in a random order.

    Each of the three ``ImageEnhance`` enhancers is applied exactly once with
    an enhancement factor drawn uniformly from ``[0.5, 1.5)``; the order in
    which they are applied is shuffled on every call.

    Args:
        img: Image accepted by the ``ImageEnhance`` enhancers.

    Returns:
        The image after all three enhancements.
    """

    def jitter(enhancer_cls, image, lower=0.5, upper=1.5):
        # Draw the factor at application time so the sequence of random
        # draws follows the shuffled order.
        factor = np.random.uniform(lower, upper)
        return enhancer_cls(image).enhance(factor)

    # Brightness, contrast and colour enhancers, applied in shuffled order.
    enhancers = [
        ImageEnhance.Brightness,
        ImageEnhance.Contrast,
        ImageEnhance.Color,
    ]
    np.random.shuffle(enhancers)

    for enhancer_cls in enhancers:
        img = jitter(enhancer_cls, img)
    return img
4. 随机水平翻转竖直翻转
def random_flip_horizon(self, img, boxes):
    """Flip the image horizontally with probability 0.5 and mirror the boxes.

    Args:
        self: Unused by the body.
        img: Image accepted by ``transforms.RandomHorizontalFlip`` that also
            exposes a ``width`` attribute.
        boxes: Tensor whose rows are ``[x1, y1, x2, y2]``; the x columns are
            updated in place when the flip happens.

    Returns:
        ``(img, boxes)``; both are returned untouched when no flip happens.
    """
    # Guard clause: leave everything as-is for the no-flip half of the draws.
    if np.random.random() <= 0.5:
        return img, boxes

    # p=1.0 forces the flip; the random decision was already taken above.
    flipper = transforms.RandomHorizontalFlip(p=1.0)
    img = flipper(img)

    width = img.width
    # Mirroring swaps the roles of the left and right edges.
    mirrored_left = width - boxes[:, 2]
    mirrored_right = width - boxes[:, 0]
    boxes[:, 0] = mirrored_left
    boxes[:, 2] = mirrored_right
    return img, boxes
def random_flip_vertical(self, img, boxes):
    """Flip the image vertically with probability 0.5 and mirror the boxes.

    Args:
        self: Unused by the body.
        img: Image accepted by ``transforms.RandomVerticalFlip`` that also
            exposes a ``height`` attribute.
        boxes: Tensor whose rows are ``[x1, y1, x2, y2]``; the y columns are
            updated in place when the flip happens.

    Returns:
        ``(img, boxes)``; both are returned untouched when no flip happens.
    """
    # Guard clause: leave everything as-is for the no-flip half of the draws.
    if np.random.random() <= 0.5:
        return img, boxes

    # p=1.0 forces the flip; the random decision was already taken above.
    flipper = transforms.RandomVerticalFlip(p=1.0)
    img = flipper(img)

    height = img.height
    # Mirroring swaps the roles of the top and bottom edges.
    mirrored_top = height - boxes[:, 3]
    mirrored_bottom = height - boxes[:, 1]
    boxes[:, 1] = mirrored_top
    boxes[:, 3] = mirrored_bottom
    return img, boxes
5. 可视化代码
def visualization(self, img, boxes):
    """Draw every box onto ``img`` and display the result via ``img.show()``.

    Args:
        self: Unused by the body.
        img: Image accepted by ``ImageDraw.Draw``; the rectangles are drawn
            through that drawing object.
        boxes: Tensor whose rows are ``[x1, y1, x2, y2]``; coordinates are
            converted with ``.int()`` before drawing.
    """
    canvas = ImageDraw.Draw(img)
    for row in boxes:
        corners = row.int()
        top_left = (corners[0], corners[1])
        bottom_right = (corners[2], corners[3])
        canvas.rectangle([top_left, bottom_right], outline='yellow', width=3)
    img.show()
6. img 和bbox 的读取
其中，我的 bbox txt 文件每行对应一个目标，字段之间用空格分隔，格式如下：
第一个字段是 label 信息，后面四个字段依次是 x1, y1, x2, y2，即 bbox 对角线两个端点的坐标。
# Load the example image.
img = Image.open("outputs/001/100.jpg")

# Read the annotation file inside a context manager so the handle is closed.
with open("outputs/detections/100.txt", "r") as f:
    bboxs = f.read().split("\n")

# Each non-empty line is "<label> x1 y1 x2 y2"; keep only the coordinates.
rows = []
for bb in bboxs:
    # split() without arguments tolerates repeated spaces, tabs and a
    # trailing "\r", and yields an empty list for blank lines.
    fields = bb.split()
    if not fields:
        continue
    bb = [float(i) for i in fields[1:]]
    print(bb)
    # reshape(1, 4) still rejects lines that do not hold four coordinates.
    rows.append(torch.Tensor(bb).reshape(1, 4))

# Concatenate once instead of growing the tensor line by line; aa stays None
# when the file contains no boxes.
aa = torch.cat(rows, 0) if rows else None