前言
标注数据集是深度学习中很重要的一步,但人工标注非常繁琐,而自动标注可以减轻工作量。做法是:用训练好的模型检测目标,输出包含目标类别和位置的txt文件,再将其转换为xml文件,最后使用标注工具进行人工检查和完善。所用模型的精度越高越好;如果检测结果不准确,反而会增加修正的工作量,甚至使自动标注失去意义。
- 加载框架
- 加载模型并且自动标注
- 验证修改
部分代码
def detector(frame, model, device, half=True, *, img_size=640,
             conf_thres=0.4, iou_thres=0.1):
    """Run the detection model on one image and return its boxes.

    Args:
        frame: Source image array indexed as (height, width, channel); axis 2
            is moved to the front before inference.
        model: Callable detection model. The first element of its output is
            handed to ``non_max_suppression``.
        device: Torch device (or device string) the input tensor is moved to.
        half: Convert the input tensor to float16 when true, float32 otherwise.
            NOTE(review): presumably this must match the precision the model
            was loaded with -- confirm against the caller.
        img_size: Target size passed to ``letterbox`` as ``new_shape``.
        conf_thres: Confidence threshold forwarded to ``non_max_suppression``.
        iou_thres: IoU threshold forwarded to ``non_max_suppression``.

    Returns:
        A list of ``[xmin, ymin, xmax, ymax, class_index]`` entries for the
        first image in the batch, with coordinates as rounded floats produced
        by ``scale_coords`` against ``frame.shape``; ``None`` when that image
        has no detections or the prediction list is empty.
    """
    # letterbox() (defined elsewhere) prepares the image at img_size; its
    # first return value is the image array.
    img = letterbox(frame, new_shape=img_size)[0]
    # Reorder HWC -> CHW.
    # NOTE(review): no channel swap is performed here, although the previous
    # inline comment claimed "BGR to RGB". Confirm which channel order the
    # model expects before changing this.
    img = np.ascontiguousarray(img.transpose(2, 0, 1))
    img = torch.from_numpy(img).to(device)
    img = img.half() if half else img.float()
    img /= 255.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        pred = model(img)[0]
        pred = non_max_suppression(
            pred, conf_thres=conf_thres, iou_thres=iou_thres
        )

    # Only the first entry is reported: the function returns from inside the
    # first loop iteration, as the original code did.
    for det in pred:
        if det is None or not len(det):
            return None
        det[:, :4] = scale_coords(
            img.shape[2:], det[:, :4], frame.shape
        ).round()
        info_list = []
        for *xyxy, _conf, cls in det:
            xyxy = torch.tensor(xyxy).view(-1).tolist()
            info_list.append([xyxy[0], xyxy[1], xyxy[2], xyxy[3], int(cls)])
        return info_list
    return None
def _format_coord(value):
    """Return the text for one bndbox coordinate; 12.0 becomes '12'."""
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return str(value)


def create_object(root, xi, yi, xa, ya, obj_name):
    """Append one ``<object>`` entry with its bounding box to ``root``.

    Args:
        root: Parent element the ``<object>`` node is appended to.
        xi: Value written to ``<xmin>``.
        yi: Value written to ``<ymin>``.
        xa: Value written to ``<xmax>``.
        ya: Value written to ``<ymax>``.
        obj_name: Label written to ``<name>`` (converted with ``str``).

    Returns:
        The newly created ``<object>`` element.
    """
    _object = ET.SubElement(root, 'object')
    ET.SubElement(_object, 'name').text = str(obj_name)
    ET.SubElement(_object, 'pose').text = 'Unspecified'
    ET.SubElement(_object, 'truncated').text = '0'
    ET.SubElement(_object, 'difficult').text = '0'

    # Create the bndbox. Whole-valued floats are written as integers so that
    # readers which parse the text with int() accept the file.
    bndbox = ET.SubElement(_object, 'bndbox')
    for tag, value in (('xmin', xi), ('ymin', yi), ('xmax', xa), ('ymax', ya)):
        ET.SubElement(bndbox, tag).text = _format_coord(value)
    return _object
def create_tree(image_name, h, w, imgdir, depth=3):
    """Start a new Pascal VOC-style ``<annotation>`` tree for one image.

    The new root is stored in the module-level ``annotation`` global and is
    also returned.

    Args:
        image_name: File name written to ``<filename>`` and used as the last
            component of ``<path>``.
        h: Image height written to ``<size>/<height>``.
        w: Image width written to ``<size>/<width>``.
        imgdir: Directory name written to ``<folder>`` and joined onto the
            current working directory to build ``<path>``.
        depth: Channel count written to ``<size>/<depth>`` (default 3).

    Returns:
        The root ``<annotation>`` element.
    """
    import os  # local import so the block does not rely on file-level names

    global annotation
    annotation = ET.Element('annotation')
    ET.SubElement(annotation, 'folder').text = imgdir
    ET.SubElement(annotation, 'filename').text = image_name
    # os.path.join uses the platform separator. The previous '\{}\{}' format
    # string hard-coded backslashes and relied on an invalid escape sequence.
    ET.SubElement(annotation, 'path').text = os.path.join(
        os.getcwd(), imgdir, image_name
    )

    source = ET.SubElement(annotation, 'source')
    ET.SubElement(source, 'database').text = 'Unknown'

    size = ET.SubElement(annotation, 'size')
    ET.SubElement(size, 'width').text = str(w)
    ET.SubElement(size, 'height').text = str(h)
    ET.SubElement(size, 'depth').text = str(depth)

    ET.SubElement(annotation, 'segmented').text = '0'
    return annotation
def pretty_xml(element, indent, newline, level=0):
    """Add indentation whitespace to an element tree in place.

    Rewrites ``text`` on elements that have children and ``tail`` on every
    child so that the serialized XML puts each element on its own line.
    Elements without children keep their ``text`` untouched.

    Args:
        element: The element to format; its descendants are handled
            recursively.
        indent: String used for one indentation level (e.g. ``'\\t'``).
        newline: String used as the line break (e.g. ``'\\n'``).
        level: Nesting depth of ``element``; 0 for the root.
    """
    children = list(element)

    # Check the child count explicitly: truth-testing an Element is deprecated
    # and must not be relied on to mean "has children".
    if children:
        inner_pad = newline + indent * (level + 1)
        if element.text is None or element.text.isspace():
            element.text = inner_pad
        else:
            element.text = inner_pad + element.text.strip() + inner_pad

    last = len(children) - 1
    for position, child in enumerate(children):
        if position < last:
            # A sibling follows, so the next line stays at the child level.
            child.tail = newline + indent * (level + 1)
        else:
            # The parent's closing tag follows, so step back one level.
            child.tail = newline + indent * level
        pretty_xml(child, indent, newline, level=level + 1)