初学习OD目标检测常用脚本

涉及所学内容,也想写个内容保存起来以便自用~

下面均是OD常用脚本,后续会持续更新!代码来自各位公开好心善良的博主,但由于没有记住各博主ID,因此无法实际艾特。特此感谢各位博主~

① SSD划分数据集txt文件

import os
import random

trainval_percent = 0.5
train_percent = 0.5
xmlfilepath = 'E:/Annotations'
txtsavepath = 'E:/ImageSets/Main'
total_xml = os.listdir(xmlfilepath)

num = len(total_xml)
list = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(list, tv)
train = random.sample(trainval, tr)

ftrainval = open(txtsavepath + '/trainval.txt', 'w')
ftest = open(txtsavepath + '/test.txt', 'w')
ftrain = open(txtsavepath + '/train.txt', 'w')
fval = open(txtsavepath + '/val.txt', 'w')

for i in list:
    name = total_xml[i][:-4] + '\n'
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest.close()

效果:

② 统计XML标签文件中各标签类别及数量

# -*- coding:utf-8 -*-
import os
import xml.etree.ElementTree as ET
import numpy as np

np.set_printoptions(suppress=True, threshold=100000000)
import matplotlib
from PIL import Image


def parse_obj(xml_path, filename):
    tree = ET.parse(xml_path + filename)
    objects = []
    for obj in tree.findall('object'):
        obj_struct = {}
        obj_struct['name'] = obj.find('name').text
        objects.append(obj_struct)
    return objects


def read_image(image_path, filename):
    im = Image.open(image_path + filename)
    W = im.size[0]
    H = im.size[1]
    area = W * H
    im_info = [W, H, area]
    return im_info


if __name__ == '__main__':
    xml_path = 'ssd-2/RTTS/Annotations/'
    filenamess = os.listdir(xml_path)
    filenames = []
    for name in filenamess:
        name = name.replace('.xml', '')
        filenames.append(name)
    recs = {}
    obs_shape = {}
    classnames = []
    num_objs = {}
    obj_avg = {}
    for i, name in enumerate(filenames):
        recs[name] = parse_obj(xml_path, name + '.xml')
    for name in filenames:
        for object in recs[name]:
            if object['name'] not in num_objs.keys():
                num_objs[object['name']] = 1
            else:
                num_objs[object['name']] += 1
            if object['name'] not in classnames:
                classnames.append(object['name'])
    for name in classnames:
        print('{}:{}个'.format(name, num_objs[name]))
    print('信息统计算完毕。')

效果:

③ 统计XML标签文件中各标签类别


"""
Created on Mon Nov 28 08:26:28 2022

@author: zqq
"""


import xml.etree.ElementTree as ET
import os


file_path_xml = r"E:\image_haze\Data_Set\Reside_A_benchmark\Extended Version\RTTS\RTTS\Annotations\\" # 2个斜杠

labelName = set([])


files_list = os.listdir(file_path_xml)
for i in files_list:
    file_dir = file_path_xml + i
    tree = ET.ElementTree(file=file_dir)
    root = tree.getroot()
    ObjectSet = root.findall('object/name')
    for Object in ObjectSet:
        labelName.add(Object.text)
print(sorted(labelName))

效果:

④ YOLO类型标签 to XML类型

# yolo格式数据集转VOC格式的代码如下:
#
# 24行:更改类别名,顺序要按yolo标注的顺序写
#
# 67、101行:更改图片格式
#
# 107行:更改图片的路径
#
# 216、218、220行:更改文件夹路径地址

# -*- coding: utf-8 -*-

import os
import xml.etree.ElementTree as ET
from xml.dom.minidom import Document
import cv2

'''
import xml
xml.dom.minidom.Document().writexml()
def writexml(self,
             writer: Any,
             indent: str = "",
             addindent: str = "",
             newl: str = "",
             encoding: Any = None) -> None
'''


class YOLO2VOCConvert:
    def __init__(self, txts_path, xmls_path, imgs_path):
        self.txts_path = txts_path  # 标注的yolo格式标签文件路径
        self.xmls_path = xmls_path  # 转化为voc格式标签之后保存路径
        self.imgs_path = imgs_path  # 读取读片的路径各图片名字,存储到xml标签文件中
        self.classes = ['bicycle', 'bus', 'car', 'motorbike', 'person']

    # 从所有的txt文件中提取出所有的类别, yolo格式的标签格式类别为数字 0,1,...
    # writer为True时,把提取的类别保存到'./Annotations/classes.txt'文件中
    def search_all_classes(self, writer=False):
        # 读取每一个txt标签文件,取出每个目标的标注信息
        all_names = set()
        txts = os.listdir(self.txts_path)
        # 使用列表生成式过滤出只有后缀名为txt的标签文件
        txts = [txt for txt in txts if txt.split('.')[-1] == 'txt']
        print(len(txts), txts)
        # 11 ['0002030.txt', '0002031.txt', ... '0002039.txt', '0002040.txt']
        for txt in txts:
            txt_file = os.path.join(self.txts_path, txt)
            with open(txt_file, 'r') as f:
                objects = f.readlines()
                for object in objects:
                    object = object.strip().split(' ')
                    print(object)  # ['2', '0.506667', '0.553333', '0.490667', '0.658667']
                    all_names.add(int(object[0]))
            # print(objects)  # ['2 0.506667 0.553333 0.490667 0.658667\n', '0 0.496000 0.285333 0.133333 0.096000\n', '8 0.501333 0.412000 0.074667 0.237333\n']

        print("所有的类别标签:", all_names, "共标注数据集:%d张" % len(txts))

        return list(all_names)

    def yolo2voc(self):
        # 创建一个保存xml标签文件的文件夹
        if not os.path.exists(self.xmls_path):
            os.mkdir(self.xmls_path)

        # 把上面的两个循环改写成为一个循环:
        imgs = os.listdir(self.imgs_path)
        txts = os.listdir(self.txts_path)
        txts = [txt for txt in txts if not txt.split('.')[0] == "classes"]  # 过滤掉classes.txt文件
        print(txts)
        # 注意,这里保持图片的数量和标签txt文件数量相等,且要保证名字是一一对应的   (后面改进,通过判断txt文件名是否在imgs中即可)
        if len(imgs) == len(txts):  # 注意:./Annotation_txt 不要把classes.txt文件放进去
            map_imgs_txts = [(img, txt) for img, txt in zip(imgs, txts)]
            txts = [txt for txt in txts if txt.split('.')[-1] == 'txt']
            print(len(txts), txts)
            for img_name, txt_name in map_imgs_txts:
                # 读取图片的尺度信息
                img_name = txt_name.split('.')[0] + '.jpg'
                print("读取图片:", img_name)
                img = cv2.imread(os.path.join(self.imgs_path, img_name))
                height_img, width_img, depth_img = img.shape
                print(height_img, width_img, depth_img)  # h 就是多少行(对应图片的高度), w就是多少列(对应图片的宽度)

                # 获取标注文件txt中的标注信息
                all_objects = []
                txt_file = os.path.join(self.txts_path, txt_name)
                with open(txt_file, 'r') as f:
                    objects = f.readlines()
                    for object in objects:
                        object = object.strip().split(' ')
                        all_objects.append(object)
                        print(object)  # ['2', '0.506667', '0.553333', '0.490667', '0.658667']

                # 创建xml标签文件中的标签
                xmlBuilder = Document()
                # 创建annotation标签,也是根标签
                annotation = xmlBuilder.createElement("annotation")

                # 给标签annotation添加一个子标签
                xmlBuilder.appendChild(annotation)

                # 创建子标签folder
                folder = xmlBuilder.createElement("folder")
                # 给子标签folder中存入内容,folder标签中的内容是存放图片的文件夹,例如:JPEGImages
                folderContent = xmlBuilder.createTextNode(self.imgs_path.split('/')[-1])  # 标签内存
                folder.appendChild(folderContent)  # 把内容存入标签
                annotation.appendChild(folder)  # 把存好内容的folder标签放到 annotation根标签下

                # 创建子标签filename
                filename = xmlBuilder.createElement("filename")
                # 给子标签filename中存入内容,filename标签中的内容是图片的名字,例如:000250.jpg
                filenameContent = xmlBuilder.createTextNode(txt_name.split('.')[0] + '.jpg')  # 标签内容
                filename.appendChild(filenameContent)
                annotation.appendChild(filename)

                # path
                path = xmlBuilder.createElement("path")
                pathContent = xmlBuilder.createTextNode(
                    '/home/seucar/Sunyx/ssd.pytorch-master/data/VOCdevkit/VOC2007/JPEGImages/' + txt_name.split('.')[
                        0] + '.jpg')
                path.appendChild(pathContent)
                annotation.appendChild(path)

                # source
                source = xmlBuilder.createElement("source")
                database = xmlBuilder.createElement("database")
                databaseContent = xmlBuilder.createTextNode('Unknown')
                database.appendChild(databaseContent)
                source.appendChild(database)
                annotation.appendChild(source)

                # 把图片的shape存入xml标签中
                size = xmlBuilder.createElement("size")
                # 给size标签创建子标签width
                width = xmlBuilder.createElement("width")  # size子标签width
                widthContent = xmlBuilder.createTextNode(str(width_img))
                width.appendChild(widthContent)
                size.appendChild(width)  # 把width添加为size的子标签
                # 给size标签创建子标签height
                height = xmlBuilder.createElement("height")  # size子标签height
                heightContent = xmlBuilder.createTextNode(str(height_img))  # xml标签中存入的内容都是字符串
                height.appendChild(heightContent)
                size.appendChild(height)  # 把width添加为size的子标签
                # 给size标签创建子标签depth
                depth = xmlBuilder.createElement("depth")  # size子标签width
                depthContent = xmlBuilder.createTextNode(str(depth_img))
                depth.appendChild(depthContent)
                size.appendChild(depth)  # 把width添加为size的子标签
                annotation.appendChild(size)  # 把size添加为annotation的子标签

                # segmented
                segmented = xmlBuilder.createElement("segmented")
                segmentedContent = xmlBuilder.createTextNode('0')
                segmented.appendChild(segmentedContent)
                annotation.appendChild(segmented)

                # 每一个object中存储的都是['2', '0.506667', '0.553333', '0.490667', '0.658667']一个标注目标
                for object_info in all_objects:
                    # 开始创建标注目标的label信息的标签
                    object = xmlBuilder.createElement("object")  # 创建object标签
                    # 创建label类别标签
                    # 创建name标签
                    imgName = xmlBuilder.createElement("name")  # 创建name标签
                    imgNameContent = xmlBuilder.createTextNode(self.classes[int(object_info[0])])
                    imgName.appendChild(imgNameContent)
                    object.appendChild(imgName)  # 把name添加为object的子标签

                    # 创建pose标签
                    pose = xmlBuilder.createElement("pose")
                    poseContent = xmlBuilder.createTextNode("Unspecified")
                    pose.appendChild(poseContent)
                    object.appendChild(pose)  # 把pose添加为object的标签

                    # 创建truncated标签
                    truncated = xmlBuilder.createElement("truncated")
                    truncatedContent = xmlBuilder.createTextNode("0")
                    truncated.appendChild(truncatedContent)
                    object.appendChild(truncated)

                    # 创建difficult标签
                    difficult = xmlBuilder.createElement("difficult")
                    difficultContent = xmlBuilder.createTextNode("0")
                    difficult.appendChild(difficultContent)
                    object.appendChild(difficult)

                    # 先转换一下坐标
                    # (objx_center, objy_center, obj_width, obj_height)->(xmin,ymin, xmax,ymax)
                    x_center = float(object_info[1]) * width_img + 1
                    y_center = float(object_info[2]) * height_img + 1
                    xminVal = int(x_center - 0.5 * float(object_info[3]) * width_img)  # object_info列表中的元素都是字符串类型
                    yminVal = int(y_center - 0.5 * float(object_info[4]) * height_img)
                    xmaxVal = int(x_center + 0.5 * float(object_info[3]) * width_img)
                    ymaxVal = int(y_center + 0.5 * float(object_info[4]) * height_img)

                    # 创建bndbox标签(三级标签)
                    bndbox = xmlBuilder.createElement("bndbox")
                    # 在bndbox标签下再创建四个子标签(xmin,ymin, xmax,ymax) 即标注物体的坐标和宽高信息
                    # 在voc格式中,标注信息:左上角坐标(xmin, ymin) (xmax, ymax)右下角坐标
                    # 1、创建xmin标签
                    xmin = xmlBuilder.createElement("xmin")  # 创建xmin标签(四级标签)
                    xminContent = xmlBuilder.createTextNode(str(xminVal))
                    xmin.appendChild(xminContent)
                    bndbox.appendChild(xmin)
                    # 2、创建ymin标签
                    ymin = xmlBuilder.createElement("ymin")  # 创建ymin标签(四级标签)
                    yminContent = xmlBuilder.createTextNode(str(yminVal))
                    ymin.appendChild(yminContent)
                    bndbox.appendChild(ymin)
                    # 3、创建xmax标签
                    xmax = xmlBuilder.createElement("xmax")  # 创建xmax标签(四级标签)
                    xmaxContent = xmlBuilder.createTextNode(str(xmaxVal))
                    xmax.appendChild(xmaxContent)
                    bndbox.appendChild(xmax)
                    # 4、创建ymax标签
                    ymax = xmlBuilder.createElement("ymax")  # 创建ymax标签(四级标签)
                    ymaxContent = xmlBuilder.createTextNode(str(ymaxVal))
                    ymax.appendChild(ymaxContent)
                    bndbox.appendChild(ymax)

                    object.appendChild(bndbox)
                    annotation.appendChild(object)  # 把object添加为annotation的子标签
                f = open(os.path.join(self.xmls_path, txt_name.split('.')[0] + '.xml'), 'w')
                xmlBuilder.writexml(f, indent='\t', newl='\n', addindent='\t', encoding='utf-8')
                f.close()


if __name__ == '__main__':
    # 把yolo的txt标签文件转化为voc格式的xml标签文件
    # yolo格式txt标签文件相对路径
    txts_path1 = './labels'
    # 转化为voc格式xml标签文件存储的相对路径
    xmls_path1 = './Annotations'
    # 存放图片的相对路径
    imgs_path1 = './JPEGImages'

    yolo2voc_obj1 = YOLO2VOCConvert(txts_path1, xmls_path1, imgs_path1)
    labels = yolo2voc_obj1.search_all_classes()
    print('labels: ', labels)
    yolo2voc_obj1.yolo2voc()

⑤ 将文件夹的图片分成固定的文件夹内 如共10000张图片,分到3个文件夹中保存

import os
import shutil

# 指定原始图片所在的文件夹路径和目标文件夹路径
original_folder = "D:/BaiduNetdiskDownload/2"
target_folder = "D:/BaiduNetdiskDownload/2_1"

# 定义每个子文件夹中包含的图片数量
batch_size = 5000

# 获取原始文件夹中的所有图片文件名
all_files = os.listdir(original_folder)
image_files = [f for f in all_files if f.endswith(".jpg") or f.endswith(".png")]
#创建子文件夹并将图片移动到相应的子文件夹中
batch_num = 1
batch_folder = os.path.join(target_folder, f"batch{batch_num}")
os.makedirs(batch_folder, exist_ok=True)
#
for i, image_file in enumerate(image_files):
    if i > 0 and i % batch_size == 0:
        batch_num += 1
        batch_folder = os.path.join(target_folder, f"batch{batch_num}")
        os.makedirs(batch_folder, exist_ok=True)

    source_path = os.path.join(original_folder, image_file)
    target_path = os.path.join(batch_folder, image_file)
    shutil.move(source_path, target_path)

print(f"{len(image_files)} images have been moved to {batch_num} subfolders.")

⑥ 视频插帧处理

import os
import cv2


def decode_video(video_path, save_dir, target_num=None):
    '''
    video_path: 待解码的视频
    save_dir: 抽帧图片的保存文件夹
    target_num: 抽帧的数量, 为空则解码全部帧, 默认抽全部帧
    '''
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    video = cv2.VideoCapture()
    if not video.open(video_path):
        print("can not open the video")
        exit(1)
    count = 0
    index = 0
    frames_num = video.get(7)
    # 如果target_num为空就全部抽帧,不为空就抽target_num帧
    if target_num is None:
        step = 1
        print('all frame num is {}, decode all'.format(int(frames_num)))
    else:
        step = int(frames_num / target_num)
        print('all frame num is {}, decode sample num is {}'.format(int(frames_num), int(target_num)))
    while True:
        _, frame = video.read()
        if frame is None:
            break
        if count % step == 0:
            save_path = "{}/{:>04d}.png".format(save_dir, index)
            cv2.imwrite(save_path, frame)
            index += 1
        count += 1
        if index == frames_num and target_num is None:
            # 如果全部抽,抽到所有帧的最后一帧就停止
            break
        elif index == target_num and target_num is not None:
            # 如果采样抽,抽到target_num就停止
            break
        else:
            pass
    video.release()


if __name__ == '__main__':
    video_path = 'D:/BaiduNetdiskDownload/video/foggy.mp4'
    save_dir_1 = 'D:/BaiduNetdiskDownload/video_images/2'
    save_dir_2 = './images_sample'
    decode_video(video_path, save_dir_1)
    decode_video(video_path, save_dir_1)



import cv2
import os
import threading
def video_to_frames(video_path, outPutDirName):
    times = 0

    # 提取视频的频率,每1帧提取一个
    frame_frequency = 1

    # 如果文件目录不存在则创建目录
    if not os.path.exists(outPutDirName):
        os.makedirs(outPutDirName)

    # 读取视频帧
    camera = cv2.VideoCapture(video_path)

    while True:
        times = times + 1
        res, image = camera.read()
        if not res:
            print('not res , not image')
            break
        if times % frame_frequency == 0:
            cv2.imwrite(outPutDirName + '\\' + str(times) + '.jpg', image)

    print('图片提取结束')
    camera.release()


if __name__ == "__main__":
    input_dir = 'D:/BaiduNetdiskDownload/雾天汽车行驶视频.mp4'  # 输入的video文件夹位置
    save_dir = r'D:\BaiduNetdiskDownload\video_images'  # 输出图片到当前目录video文件夹下
    count = 0  # 视频数
    for video_name in os.listdir(input_dir):
        video_path = os.path.join(input_dir, video_name)
        outPutDirName = os.path.join(save_dir, video_name[:-4])
        threading.Thread(target=video_to_frames, args=(video_path, outPutDirName)).start()
        count = count + 1
        print("%s th video has been finished!" % count)

⑦按照自定义划分文件夹

import os
import shutil

# 设置输入和输出文件夹路径
input_folder = "H:/foggy_train_images"
output_folder = "H:/foggy"

# 设置需要分离的字符串
split_string = "0.02"

# 遍历输入文件夹中的所有文件
for filename in os.listdir(input_folder):
    # 跳过非图像文件
    if not filename.endswith(".jpg") and not filename.endswith(".png"):
        continue

    # 检查文件名是否包含指定的分离字符串
    if split_string not in filename:
        continue

    # 创建对应的输出文件夹
    output_path = os.path.join(output_folder, split_string)
    os.makedirs(output_path, exist_ok=True)

    # 将图像文件复制到对应的输出文件夹中
    image_path = os.path.join(input_folder, filename)
    output_file = os.path.join(output_path, filename)
    shutil.copy(image_path, output_file)

⑧以txt文件末端判断(按照自定义),移动txt文件

import os
import shutil

# 设置原始文件夹和目标文件夹的路径
src_folder = 'F:/foggy_cityscapes_all_info/foggy_city_txt/foggy_city_txt/foggy_city_txt_val/'
dst_folder = 'F:/foggy_cityscapes_all_info/foggy_city_txt/0.02_val/'
if not os.path.exists(dst_folder):
    os.makedirs(dst_folder)
# 获取原始文件夹中以“0.02.txt”结尾的文件列表
file_list = [os.path.join(src_folder, f) for f in os.listdir(src_folder) if f.endswith('0.02.txt')]

# 将符合条件的文件复制到目标文件夹中
for file_path in file_list:
    shutil.copy(file_path, dst_folder)

⑨ 查看coco数据格式.txt文件的标注文件是否存在重复框(建议做好原文件的存储哦~)

import os
from collections import defaultdict


def remove_duplicates(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()

    unique_labels = set()
    new_lines = []
    for line in lines:
        label = line.strip()
        if label not in unique_labels:
            unique_labels.add(label)
            new_lines.append(line)

    with open(file_path, 'w') as file:
        for line in new_lines:
            file.write(line)

    return len(lines) - len(new_lines)


def remove_duplicates_in_folder(folder_path):
    num_deleted = 0
    deleted_files = defaultdict(int)

    for root, dirs, files in os.walk(folder_path):
        for file_name in files:
            if file_name.endswith(".txt"):
                file_path = os.path.join(root, file_name)
                deleted_count = remove_duplicates(file_path)
                if deleted_count > 0:
                    deleted_files[file_path] += deleted_count
                    num_deleted += deleted_count

    print(f"Duplicate labels have been removed from {num_deleted} labels.")

    if num_deleted > 0:
        print("Deleted files:")
        for file_path, count in deleted_files.items():
            print(f"{file_path}: {count}")


# 使用示例
folder_path = 'H:/labels_new/test/'  # 替换为实际的文件夹路径
remove_duplicates_in_folder(folder_path)

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是使用Python和TensorFlow实现目标检测的示例代码: 1.导入必要的库和模块 ``` import numpy as np import os import six.moves.urllib as urllib import sys import tarfile import tensorflow as tf import zipfile from collections import defaultdict from io import StringIO from matplotlib import pyplot as plt from PIL import Image ``` 2.添加TensorFlow模型库到系统路径 ``` sys.path.append("..") ``` 3.从TensorFlow模型库中导入目标检测API ``` from object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util ``` 4.下载和解压缩预训练模型 ``` MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017' MODEL_FILE = MODEL_NAME + '.tar.gz' DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/' PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' opener = urllib.request.URLopener() opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) tar_file = tarfile.open(MODEL_FILE) for file in tar_file.getmembers(): file_name = os.path.basename(file.name) if 'frozen_inference_graph.pb' in file_name: tar_file.extract(file, os.getcwd()) ``` 5.加载标签图和类别映射 ``` PATH_TO_LABELS = os.path.join('object_detection', 'data', 'mscoco_label_map.pbtxt') NUM_CLASSES = 90 label_map = label_map_util.load_labelmap(PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) category_index = label_map_util.create_category_index(categories) ``` 6.加载预训练模型 ``` detection_graph = tf.Graph() with detection_graph.as_default(): od_graph_def = tf.GraphDef() with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name='') ``` 7.创建会话并运行目标检测 ``` def run_inference_for_single_image(image, graph): with graph.as_default(): with tf.Session() as sess: # 输入和输出张量的名称 image_tensor = graph.get_tensor_by_name('image_tensor:0') detection_boxes = graph.get_tensor_by_name('detection_boxes:0') detection_scores = graph.get_tensor_by_name('detection_scores:0') detection_classes = graph.get_tensor_by_name('detection_classes:0') num_detections = graph.get_tensor_by_name('num_detections:0') # 执行目标检测 (boxes, scores, classes, num) = sess.run( [detection_boxes, detection_scores, detection_classes, num_detections], feed_dict={image_tensor: np.expand_dims(image, 0)}) # 过滤掉分数低于阈值的目标 boxes = np.squeeze(boxes) scores = np.squeeze(scores) classes = np.squeeze(classes).astype(np.int32) indices = np.where(scores > 0.5)[0] boxes = boxes[indices] scores = scores[indices] classes = classes[indices] # 返回检测结果 return boxes, scores, classes # 加载测试图片 PATH_TO_TEST_IMAGE = 'test.jpg' image = Image.open(PATH_TO_TEST_IMAGE) image_np = np.array(image) # 运行目标检测 boxes, scores, classes = run_inference_for_single_image(image_np, detection_graph) # 可视化检测结果 vis_util.visualize_boxes_and_labels_on_image_array( image_np, boxes, classes, scores, category_index, use_normalized_coordinates=True, line_thickness=8) plt.figure(figsize=(12,8)) plt.imshow(image_np) plt.show() ``` 注意:上述代码中的“PATH_TO_TEST_IMAGE”需要替换为您的测试图像的路径。此外,还需要根据您的模型更改“MODEL_NAME”和“MODEL_FILE”。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值