算法实习生

windows传到Ubuntu服务器

scp -r C:/Users/k167/Desktop/dataset/person_dataset/ ubuntu@192.168.31.35:/ai_projects/tmp/person_dataset/

统计当前文件夹下文件的个数:

ls -l |grep "^-"|wc -l

统计当前文件夹下目录的个数:

ls -l |grep "^d"|wc -l

统计当前文件夹下文件的个数,包括子文件夹里的 :

ls -lR|grep "^-"|wc -l

统计文件夹下目录的个数,包括子文件夹里的:

ls -lR|grep "^d"|wc -l

退出vim

:q

wc

 wc val.txt

在这里插入图片描述

替换路径

:%s/C:\\Users\\k167\\Desktop\\dataset/\\ai_projects\\tmp/g

在这里插入图片描述
在这里插入图片描述

nvidia

nvitop

在这里插入图片描述

xml2yolo
输入数据是img_list.txt里面包含所有的图片数据,bsuval用于验证,tokyo用于训练,去掉攀岩,用kmp算法去匹配

from tqdm import tqdm
import xml.etree.ElementTree as ET
import os
classes = ["person"]
class Solution:
    """Knuth-Morris-Pratt substring search driven by an optimised failure table."""

    def get_next(self, T):
        """Build the optimised failure ("nextval") table for pattern ``T``.

        Entry ``k`` is the pattern index to resume from after a mismatch at
        pattern position ``k``; ``-1`` means "advance the text and restart the
        pattern".  The returned list has one entry per pattern element.
        """
        size = len(T)
        table = [-1] * size
        pos, prefix = 0, -1
        while pos < size - 1:
            if prefix != -1 and T[pos] != T[prefix]:
                # Mismatch: fall back to a shorter candidate prefix and retry.
                prefix = table[prefix]
                continue
            pos += 1
            prefix += 1
            # If both positions hold the same element, falling back to
            # `prefix` would mismatch again, so reuse its own fallback.
            table[pos] = prefix if T[pos] != T[prefix] else table[prefix]
        return table

    def kmp(self, S, T):
        """Return the index of the first occurrence of ``T`` in ``S``, or -1."""
        table = self.get_next(T)
        text_len, pat_len = len(S), len(T)
        text_pos = pat_pos = 0
        while text_pos < text_len and pat_pos < pat_len:
            if pat_pos == -1 or S[text_pos] == T[pat_pos]:
                text_pos += 1
                pat_pos += 1
            else:
                pat_pos = table[pat_pos]
        # The pattern was fully consumed only if a complete match was found.
        return text_pos - pat_pos if pat_pos == pat_len else -1
def convert(size, box):
    """Convert a pixel box to normalised ``(x_center, y_center, width, height)``.

    Args:
        size: ``(image_width, image_height)`` in pixels.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels.

    Returns:
        A 4-tuple of floats, each scaled by the reciprocal of the matching
        image dimension.
    """
    inv_w = 1.0 / size[0]
    inv_h = 1.0 / size[1]
    xmin, xmax, ymin, ymax = box[0], box[1], box[2], box[3]
    center_x = (xmin + xmax) / 2.0
    center_y = (ymin + ymax) / 2.0
    box_w = xmax - xmin
    box_h = ymax - ymin
    return (center_x * inv_w, center_y * inv_h, box_w * inv_w, box_h * inv_h)


def convert_annotation(img_path):
    """Write the YOLO label file for one image from its Pascal VOC XML.

    The annotation is read from ``<stem>.xml`` and the labels are written to
    ``<stem>.txt``, where ``<stem>`` is ``img_path`` without its extension.
    Objects whose class is not in the module-level ``classes`` list are
    reported with a print and skipped.

    Args:
        img_path: Path to a ``.jpg`` or ``.png`` image.

    Raises:
        AssertionError: If ``img_path`` does not end in ``jpg`` or ``png``.
    """
    # Validate before touching the filesystem so that a bad path no longer
    # leaves a stray, empty label file behind.
    assert img_path.split('.')[-1] in ('jpg', 'png'), img_path

    # Swap only the extension: str.replace('jpg', 'xml') would also rewrite
    # "jpg"/"png" occurring inside directory or file names.
    stem = os.path.splitext(img_path)[0]
    xml_path = stem + '.xml'

    with open(xml_path) as f:
        root = ET.fromstring(f.read())
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    # The label file is opened only after the XML parsed successfully and is
    # closed deterministically by the context manager.
    with open(stem + '.txt', 'w') as out_file:
        for obj in root.iter('object'):
            cls = obj.find('name').text
            if cls not in classes:
                print('不存在',cls)
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # convert() expects the box as (xmin, xmax, ymin, ymax).
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b)
            out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
def get_txt(path):
    """Return the image paths listed in an ``img_list.txt`` file.

    Each non-blank line contributes its first space-separated field.  Entries
    whose first ``/``-separated component is ``bg_images`` are skipped.  Every
    kept entry is prefixed with whatever precedes ``img_list.txt`` in ``path``.

    Args:
        path: Path to the list file.

    Returns:
        List of prefixed entries, in file order.
    """
    path_head = path.split('img_list.txt')[0]
    txt = []
    with open(path) as f:
        for raw_line in f:
            line = raw_line.strip().split(' ')[0]
            if not line:
                # A blank line used to end the read loop early and silently
                # drop the rest of the file; skip it instead.
                continue
            if line.split('/')[0] != 'bg_images':
                txt.append(path_head + line)
    return txt
def main():
    """Build classes/train/val/crop lists and YOLO labels for the person dataset.

    Reads the image list, writes ``classes.txt`` next to it, then for every
    listed image:

    * paths containing ``panyan`` are left out of all split files;
    * paths containing ``tokyo`` go to ``train.txt``;
    * paths containing ``bsuval`` go to ``val.txt``;
    * paths containing ``bsu`` (which includes every ``bsuval`` path) go to
      ``crop.txt``;
    * the VOC XML is converted to a YOLO label file regardless of the split.
    """
    # Image list to process.
    txt_path = r'C:/Users/k167/Desktop/dataset/annotation_images/img_list.txt'
    root = os.path.dirname(txt_path)

    # classes.txt: one class name per line.
    with open(os.path.join(root, 'classes.txt'), 'w') as f:
        for category in classes:
            f.write(f"{category}\n")

    txt = get_txt(txt_path)

    # One matcher is enough; it keeps no state between searches.
    matcher = Solution()

    # Context managers guarantee the split files are flushed and closed even
    # if a conversion fails part-way through.
    with open(os.path.join(root, 'train.txt'), 'w') as train, \
            open(os.path.join(root, 'val.txt'), 'w') as val, \
            open(os.path.join(root, 'crop.txt'), 'w') as crop:
        for path in tqdm(txt):
            if matcher.kmp(path, 'panyan') == -1:
                if matcher.kmp(path, 'tokyo') > -1:
                    train.write(path + '\n')
                if matcher.kmp(path, 'bsuval') > -1:
                    val.write(path + '\n')
                if matcher.kmp(path, 'bsu') > -1:
                    crop.write(path + '\n')
            convert_annotation(path)


if __name__ == '__main__':
    main()

读txt文件

def get_txt(path):
    """Return the entries of a list file, skipping ``bg_images`` entries.

    Each non-blank line contributes its first space-separated field; entries
    whose first ``/``-separated component is ``bg_images`` are dropped.  The
    number of kept entries is printed.

    Args:
        path: Path to the list file.

    Returns:
        List of kept entries, in file order.
    """
    txt = []
    with open(path) as f:
        for raw_line in f:
            line = raw_line.strip().split(' ')[0]
            if not line:
                # A blank line used to end the read loop early and silently
                # drop the rest of the file; skip it instead.
                continue
            if line.split('/')[0] != 'bg_images':
                txt.append(line)
    print(len(txt))
    return txt

读xml voc格式数据

def read_annotations(xml_path):
    """Read the object boxes from a Pascal VOC annotation file.

    Args:
        xml_path: Path to the VOC XML file.

    Returns:
        A list of ``[class_name, image_width, image_height, x1, y1, x2, y2]``
        entries with integer (rounded) coordinates.  Boxes that are empty,
        inverted, or outside the image are skipped.  An empty list is returned
        when the image size is not positive or when any object has no name.
    """
    # xml.etree.cElementTree was removed in Python 3.9; ElementTree picks up
    # the C accelerator on its own.
    import xml.etree.ElementTree as ET

    root = ET.parse(xml_path).getroot()
    size = root.find('size')
    element_width = int(size.find('width').text)
    element_height = int(size.find('height').text)
    if element_width <= 0 or element_height <= 0:
        return []

    results = []
    for element_obj in root.findall('object'):
        name = element_obj.find('name')
        # Test the raw text: wrapping it in str() first turned a missing name
        # into the string 'None', so the "no name" branch could never run.
        if name is None or name.text is None:
            return []
        class_name = str(name.text)

        obj_bbox = element_obj.find('bndbox')
        x1 = int(round(float(obj_bbox.find('xmin').text)))
        y1 = int(round(float(obj_bbox.find('ymin').text)))
        x2 = int(round(float(obj_bbox.find('xmax').text)))
        y2 = int(round(float(obj_bbox.find('ymax').text)))
        if x1 < 0 or y1 < 0 or x2 > element_width or y2 > element_height or x1 >= x2 or y1 >= y2:
            continue
        results.append([class_name, element_width, element_height, x1, y1, x2, y2])
    return results

创建文件夹

def mkdir(path):
    """Create ``path`` (with any missing parents) unless it already exists."""
    if os.path.exists(path):
        return
    os.makedirs(path)

修改voc数据 有多个目标,
传入的box格式如下
[['person', 1.0, [1, 600, 385, 1339], [], [], [], []], ['person', 0.9092891812324524, [469, 648, 615, 981], [], [], [], []], ['person', 0.8290438055992126, [675, 660, 726, 783], [], [], [], []], ['person', 0.788838267326355, [641, 662, 680, 778], [], [], [], []], ['person', 0.6277031302452087, [38, 586, 179, 1048], [], [], [], []], ['person', 0.6049114465713501, [412, 678, 437, 745], [], [], [], []]]

import xml.etree.ElementTree as ET
def create_object(root, xi, yi, xa, ya, obj_name):
    """Append a Pascal VOC ``<object>`` element to ``root``.

    Args:
        root: Parent element that receives the new ``<object>``.
        xi, yi, xa, ya: Values written to ``xmin``, ``ymin``, ``xmax`` and
            ``ymax`` respectively.
        obj_name: Class name; written to ``<name>`` via ``str()``.
    """
    obj_node = ET.SubElement(root, 'object')

    name_node = ET.SubElement(obj_node, 'name')
    name_node.text = str(obj_name)

    # Fixed VOC bookkeeping fields.
    for tag, value in (('pose', 'Unspecified'), ('truncated', '0'), ('difficult', '0')):
        field = ET.SubElement(obj_node, tag)
        field.text = value

    # Bounding box corners, in the order VOC tools conventionally emit them.
    box_node = ET.SubElement(obj_node, 'bndbox')
    for tag, coord in (('xmin', xi), ('ymin', yi), ('xmax', xa), ('ymax', ya)):
        corner = ET.SubElement(box_node, tag)
        corner.text = '%s' % coord
def change_xml(crop,boxs,xml_path,save_xml_path):
    """Write a VOC XML for ``crop`` by patching a template annotation.

    The template's ``<size>`` is updated from ``crop.shape``, its first
    ``<object>`` receives the first box, and one extra ``person`` object is
    appended for every further box.  Nothing is written when ``boxs`` is empty.

    Args:
        crop: Image array; ``crop.shape`` is read as ``(height, width, ...)``.
        boxs: List of detections whose third item is ``[xmin, ymin, xmax, ymax]``.
        xml_path: Template VOC XML; must contain ``<size>`` and at least one
            ``<object>`` with a ``<bndbox>``.
        save_xml_path: Output path; a ``.png``/``.jpg`` suffix is replaced by
            ``.xml``.
    """
    if len(boxs) == 0:
        return

    tree = ET.parse(xml_path)
    root = tree.getroot()

    # Array shape is (rows, cols, ...), i.e. (height, width, ...).  The two
    # fields used to be written the other way round.
    img_h, img_w = crop.shape[0], crop.shape[1]
    size = root.find('size')
    size.find('width').text = str(img_w)
    size.find('height').text = str(img_h)

    # Overwrite the template's first object with the first box.
    first_bndbox = root.find('object').find('bndbox')
    first_coords = boxs[0][2]
    for idx, tag in enumerate(('xmin', 'ymin', 'xmax', 'ymax')):
        first_bndbox.find(tag).text = str(first_coords[idx])

    # Remaining boxes become additional objects.
    for box in boxs[1:]:
        create_object(root, box[2][0], box[2][1], box[2][2], box[2][3], 'person')

    tree.write(save_xml_path.replace(".png", ".xml").replace(".jpg", ".xml"))

裁剪图片,以bbox为中心,向外pad,进行裁剪
img(2160, 3840, 3)
bbox[730, 824, 1122, 1274]

def expand_img(img, bbox, img_h, img_w, pad, square=True):
    """Crop ``img`` around ``bbox`` after growing the box by ``pad``.

    Args:
        img: Image array indexed as ``img[y, x, channel]``.
        bbox: ``[x1, y1, x2, y2]``; values are truncated to int.
        img_h, img_w: Image height and width used for clamping.
        pad: Either a single padding value or a ``(pad_w, pad_h)`` pair.
        square: When true, the box is first made square using its longer side.

    Returns:
        ``(bbox, crop_img)`` where ``bbox`` is the padded box clamped to
        ``[0, img_w - 1]`` / ``[0, img_h - 1]`` and ``crop_img`` is the
        matching slice of the first three channels.
    """
    if isinstance(pad, (tuple, list)):
        pad_w, pad_h = int(pad[0]), int(pad[1])
    else:
        pad_w = pad_h = pad

    left, top, right, bottom = (int(v) for v in bbox)
    box_h = bottom - top
    box_w = right - left
    if square:
        box_h = box_w = max(box_h, box_w)

    center_x = (left + right) // 2
    center_y = (top + bottom) // 2
    half_w = box_w // 2
    half_h = box_h // 2

    def clamp(value, limit):
        # Keep a coordinate inside [0, limit - 1].
        return max(0, min(limit - 1, value))

    bbox = [
        clamp(center_x - half_w - pad_w, img_w),
        clamp(center_y - half_h - pad_h, img_h),
        clamp(center_x + half_w + pad_w, img_w),
        clamp(center_y + half_h + pad_h, img_h),
    ]
    x1, y1, x2, y2 = bbox
    return bbox, img[y1: y2, x1: x2, :3]

找冰壶视频,抽帧,筛选
生成voc格式,yolo格式
voc用labelImg生成xml文件,然后将xml格式转为yolo格式
文件夹和文件名不能有空格,会报错 os.renames(dir_name, dir_name.replace(' ', '_'))
2017_Best_Curling_Shots
2018_season_of_champions_shots
2021_AGI_Top_Shots_of_the_Year
2022_Tim_Horton’s_Brier_Top_Ten_shots
2023_Tim_Horton’s_Brier_Top_Ten_Shots为训练集
JAPAN_v_UNITED_STATES_Mixed_Doubles_Curling_Championship_2023为验证集
使用yolov8l.pt预训练模型,设置300epoch

from tqdm import tqdm
import xml.etree.ElementTree as ET
import os
classes = ["yellow", "red"]
def convert(size, box):
    """Turn a pixel box into normalised centre/size form.

    Args:
        size: ``(image_width, image_height)`` in pixels.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels.

    Returns:
        ``(x_center, y_center, width, height)``, each multiplied by the
        reciprocal of the corresponding image dimension.
    """
    scale_x = 1.0 / size[0]
    scale_y = 1.0 / size[1]
    mid_x = (box[0] + box[1]) / 2.0
    mid_y = (box[2] + box[3]) / 2.0
    span_x = box[1] - box[0]
    span_y = box[3] - box[2]
    return (mid_x * scale_x, mid_y * scale_y, span_x * scale_x, span_y * scale_y)


def convert_annotation(img_path):
    """Write the YOLO label file for one image from its Pascal VOC XML.

    The annotation is read (as gb18030 text) from ``<stem>.xml`` and the
    labels are written to ``<stem>.txt``, where ``<stem>`` is ``img_path``
    without its extension.  Objects whose class is not in the module-level
    ``classes`` list are reported with a print and skipped.

    Args:
        img_path: Path to a ``.jpg`` or ``.png`` image.

    Raises:
        AssertionError: If ``img_path`` does not end in ``jpg`` or ``png``.
    """
    # Validate before touching the filesystem so that a bad path no longer
    # leaves a stray, empty label file behind.
    assert img_path.split('.')[-1] in ('jpg', 'png'), img_path

    # Swap only the extension: str.replace('jpg', ...) would also rewrite
    # "jpg"/"png" occurring inside directory or file names.
    stem = os.path.splitext(img_path)[0]
    xml_path = stem + '.xml'
    print(xml_path)

    with open(xml_path, encoding='gb18030') as f:
        root = ET.fromstring(f.read())
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    # The label file is opened only after the XML parsed successfully and is
    # closed deterministically by the context manager.
    with open(stem + '.txt', 'w') as out_file:
        for obj in root.iter('object'):
            cls = obj.find('name').text
            if cls not in classes:
                print('不存在',cls)
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # convert() expects the box as (xmin, xmax, ymin, ymax).
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b)
            out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
def get_txt(path):
    """Return the non-blank lines of ``path`` with surrounding whitespace removed.

    Args:
        path: Path to a text file with one entry per line.

    Returns:
        List of stripped, non-empty lines in file order.
    """
    txt = []
    with open(path) as f:
        for raw_line in f:
            line = raw_line.strip()
            # A blank line used to end the read loop early and silently drop
            # the rest of the file; skip it instead.
            if line:
                txt.append(line)
    return txt
def main(img_dir=r'C:\Users\k167\Desktop\dataset\curling'):
    """Build the image lists, train/val split and YOLO labels for a dataset tree.

    Steps:
      1. Walk ``img_dir``; rename every directory and file whose name contains
         a space so that spaces become underscores.
      2. Write every ``.jpg`` path to ``img_list.txt``; images whose parent
         folder is the validation video go to ``val.txt``, all others to
         ``train.txt``.
      3. Write ``classes.txt`` from the module-level ``classes`` list.
      4. Convert the VOC XML of every listed image to a YOLO label file.

    Args:
        img_dir: Dataset root.  Defaults to the original hard-coded location.
    """
    val_folder = 'JAPAN_v_UNITED_STATES_Mixed_Doubles_Curling_Championship_2023'
    list_path = os.path.join(img_dir, 'img_list.txt')

    # All three list files are closed (and therefore flushed) before
    # img_list.txt is read back below.
    with open(list_path, 'w') as img_list, \
            open(os.path.join(img_dir, 'train.txt'), 'w') as train, \
            open(os.path.join(img_dir, 'val.txt'), 'w') as val:
        for path, dir_lst, file_lst in os.walk(img_dir):
            for idx, dir_name in enumerate(dir_lst):
                if ' ' not in dir_name:
                    continue
                new_name = dir_name.replace(' ', '_')
                # Rename relative to the directory being walked, not the CWD.
                os.rename(os.path.join(path, dir_name), os.path.join(path, new_name))
                # os.walk (top-down) descends into the names left in dir_lst,
                # so it must see the new name.
                dir_lst[idx] = new_name

            for file_name in file_lst:
                if ' ' in file_name:
                    new_name = file_name.replace(' ', '_')
                    # Full paths avoid changing the process working directory.
                    os.rename(os.path.join(path, file_name), os.path.join(path, new_name))
                    file_name = new_name

                img_path = os.path.join(path, file_name)
                if img_path.split('.')[-1] != 'jpg':
                    continue
                img_list.write(img_path + '\n')
                # Parent folder name, independent of the path separator.
                if os.path.basename(path) == val_folder:
                    val.write(img_path + '\n')
                else:
                    train.write(img_path + '\n')

    # classes.txt: one class name per line.
    with open(os.path.join(img_dir, 'classes.txt'), 'w') as f:
        for category in classes:
            f.write(f"{category}\n")

    for path in tqdm(get_txt(list_path)):
        convert_annotation(path)


if __name__ == '__main__':
    main()

在这里插入图片描述

当你执行 mask_np[mask_np == 1] = obj + 1 这行代码时,它会让所有等于1的元素被替换为 obj + 1
让我们来解释这个代码的步骤:

  1. mask_np == 1 会返回一个与 mask_np 具有相同形状的布尔数组,其中元素为True表示对应位置的元素等于1。
  2. 然后,将这个布尔数组应用于 mask_np,即只选取与 mask_np == 1 对应位置为True的元素。
  3. 最后,使用 = 运算符将选中的元素赋值为 obj + 1
    换句话说,这行代码会按照条件选取 mask_np 中值为1的元素,并将它们替换为 obj + 1 的值。这可以用来在数组中进行条件替换或更新特定的元素值。

多目标跟踪生成的mask,mask背景为0,目标为1,2。
crop_frame[:, :, ::-1] 使用[:, :, ::-1]对数组进行切片,相当于反转第三维的顺序,即将BGR顺序转换为RGB顺序

# Run the segmenter once per object and relabel the resulting mask so that
# pixels equal to 1 carry the object's 1-based id instead.
# NOTE(review): mask_np is overwritten on every iteration and nothing in this
# fragment stores it, so only the last object's mask survives — presumably the
# surrounding code (not shown) consumes it inside the loop; confirm.
for obj in range(num_objects):
	# [:, :, ::-1] reverses the last axis of crop_frame (BGR -> RGB channel
	# order, assuming an OpenCV-style frame — TODO confirm).
	mask_np, _, seg_res = segment.run(crop_frame[:, :, ::-1], crop_box)
	mask_np = mask_np.astype(np.uint8)
	# Boolean-mask assignment: every element equal to 1 becomes obj + 1.
	mask_np[mask_np == 1] = obj + 1

mask生成bbox

def get_pts(w):
    """Return the first and last index at which ``w`` is positive.

    Args:
        w: 1-D array supporting ``w > 0`` as a boolean mask.

    Returns:
        ``(first, last)`` indices, or ``(0, 0)`` when two or fewer entries are
        positive.
    """
    positions = np.array(list(range(len(w))))[w > 0]
    if len(positions) <= 2:
        return 0, 0
    return positions[0], positions[-1]


def get_bbox(thresh):
    """Return ``(x1, y1, x2, y2)`` spanning the positive area of a 2-D mask.

    Column sums give the x extent and row sums the y extent; each extent is
    resolved by ``get_pts``.
    """
    (x1, x2), (y1, y2) = (get_pts(np.sum(thresh, axis=axis)) for axis in (0, 1))
    return x1, y1, x2, y2
    
# Collect one bounding box per tracked object from the label mask, where
# object k is stored as the value k + 1.
# NOTE(review): the loop below is indented deeper than the assignment, so this
# fragment does not run as written — presumably it was pasted from inside a
# function; dedent the loop (or indent the assignment) before use.
bboxes = []
    for obj in range(num_objects):
        # Wrapping the comparison in a list adds a leading axis of length 1.
        mask = np.array([crop_mask == obj+1], dtype=np.uint8)
        # squeeze(0) removes that axis again; e.g. mask shape [1, 640, 640].
        bbox = list(get_bbox(mask.squeeze(0)))
        bboxes.append(bbox)

获得文件夹下的图片list

from tqdm import tqdm
import xml.etree.ElementTree as ET
import os


def main(img_dir=r'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'):
    """Write ``img_list.txt`` listing every ``.jpg`` under ``img_dir``.

    While walking the tree, every directory and file whose name contains a
    space is renamed so that spaces become underscores; the list contains the
    renamed paths.

    Args:
        img_dir: Root folder to scan.  Defaults to the original placeholder.
    """
    with open(os.path.join(img_dir, 'img_list.txt'), 'w') as f:
        for path, dir_lst, file_lst in os.walk(img_dir):
            for idx, dir_name in enumerate(dir_lst):
                if ' ' not in dir_name:
                    continue
                new_name = dir_name.replace(' ', '_')
                # Rename relative to the directory being walked, not the CWD.
                os.rename(os.path.join(path, dir_name), os.path.join(path, new_name))
                # os.walk (top-down) descends into the names left in dir_lst,
                # so it must see the new name.
                dir_lst[idx] = new_name

            for file_name in file_lst:
                if ' ' in file_name:
                    new_name = file_name.replace(' ', '_')
                    # Full paths avoid changing the process working directory.
                    os.rename(os.path.join(path, file_name), os.path.join(path, new_name))
                    file_name = new_name

                img_path = os.path.join(path, file_name)
                # Only jpg images are listed.
                if img_path.split('.')[-1] == 'jpg':
                    f.write(img_path + '\n')


if __name__ == '__main__':
    main()

抽帧

import cv2
import os
video_path = r'C:\Users\k167\Desktop\dataset\10m_diving_sync/2012_Diving_Women_Sync_10m.mp4'  # source video


output_path = r'C:\Users\k167\Desktop\dataset\10m_diving_sync/2012_Diving_Women_Sync_10m/'  # output folder
interval = 15  # save one frame out of every `interval` frames

os.makedirs(output_path, exist_ok=True)
if __name__ == '__main__':
    num = 1  # 1-based index of the frame currently being read
    # Video file name without its extension; used as the prefix of every saved
    # frame.  basename/splitext do not depend on where '/' occurs in the path.
    name = os.path.splitext(os.path.basename(video_path))[0]

    vid = cv2.VideoCapture(video_path)
    try:
        while vid.isOpened():
            is_read, frame = vid.read()
            if not is_read:
                break
            if num % interval == 0:
                # Saved as <video name>_<frame index>.jpg
                out_file = output_path + '%s_%d' % (name, num) + '.jpg'
                print(out_file)
                cv2.imwrite(out_file, frame)
            num += 1
    finally:
        # Release the capture handle even if reading or writing fails.
        vid.release()

要求使用cutie生成裁剪的图片和xml,但是需要设置两个pad参数,一个为跟踪时pad的大小,一个为裁剪的大小。
pad_box 是pad为200的裁剪框,crop_pad_box 是pad为600的裁剪框,要保存600的裁剪框
创建一个与原始图像尺寸相同的空白掩码数组。
根据给定的裁剪框,在空白掩码数组中填充原始掩码数据,形成一个带有裁剪框的掩码图像。
根据另一个给定的裁剪区域,从带有裁剪框的掩码图像中截取出真正需要的裁剪后的掩码数据。

# --- Tracking crop: cut the frame around the last box with the tracking pad ---
crop_input_frame, crop_frame_shape, pad_box = crop_and_pad_img_no_resize(frame, last_bbox,img_shape, input_shape,pad)
input_frame = cv2.resize(crop_input_frame, (tracker_inp_w, tracker_inp_h))
t22 = time.time()
frame_torch = image_to_torch(input_frame, device=device)
crop_prob = processor.step(frame_torch)
crop_mask = torch_prob_to_numpy_mask(crop_prob)

# Split the label mask into one binary mask per object (labels 1 and 2).
crop_mask_obj1 = (crop_mask == 1).astype(np.uint8)
crop_mask_obj2 = (crop_mask == 2).astype(np.uint8)

# Resize the masks back to the size of the tracking crop.  The interpolation
# flag must be passed by keyword: the third positional parameter of
# cv2.resize is `dst`, not `interpolation`.
crop_size = (crop_frame_shape[1], crop_frame_shape[0])  # (width, height)
crop_mask_obj1 = cv2.resize(crop_mask_obj1, crop_size, interpolation=cv2.INTER_LINEAR)
crop_mask_obj2 = cv2.resize(crop_mask_obj2, crop_size, interpolation=cv2.INTER_LINEAR)
# NOTE(review): crop_mask holds class labels; if it is used after this point,
# cv2.INTER_NEAREST would avoid blending label values at object borders.
crop_mask = cv2.resize(crop_mask, crop_size, interpolation=cv2.INTER_LINEAR)

# Rebuild a single label mask (0 background, 1 and 2 objects) at crop size.
mask_c = np.zeros([crop_frame_shape[0], crop_frame_shape[1]], dtype=np.uint8)
mask_c[crop_mask_obj1 == 1] = 1
mask_c[crop_mask_obj2 == 1] = 2

# --- Saving crop: same centre box, but cut with the (larger) crop pad ---
crop_crop_input_frame, crop_crop_frame_shape, crop_pad_box = crop_and_pad_img_no_resize(frame, last_bbox, img_shape, input_shape,crop_pad)
# Paste the tracking-crop mask into a full-frame canvas ...
mask_cr = np.zeros([img_h, img_w], dtype=np.uint8)
pad_x1, pad_y1, pad_x2, pad_y2 = pad_box
mask_cr[pad_y1: pad_y2, pad_x1: pad_x2] = mask_c
# ... then cut the saving-crop window out of that canvas.
pad_x1, pad_y1, pad_x2, pad_y2 = crop_pad_box
pad_mask = mask_cr[pad_y1: pad_y2, pad_x1: pad_x2]

# if((current_frame_index - 1) % 10 == 0):
# NOTE(review): this call passes five arguments; make sure it matches the
# signature of the crop() variant actually in use.
crop(video_dir, pad_mask, crop_crop_input_frame, num_objects, current_frame_index)

每个视频用cutie生成img和xml,以视频名为文件夹,每个项目20个
跟踪预测框 检测框 做iou,保留最大的那个
多人做匈牙利匹配,匹配的依据是计算目标边界框与检测结果边界框之间的IoU值,并使用匈牙利匹配算法找到最佳的匹配结果。

def crop(video_dir, crop_mask, input_frame, crop_input_frame, num_objects, current_frame_index,pad_box):
    """Save one frame plus a VOC XML holding the boxes of the tracked objects.

    With more than one object, detector boxes (shifted by the ``pad_box``
    origin) are assigned to the tracker's mask boxes with
    ``linear_sum_assignment`` over ``cal_iou`` scores, and the assigned
    detector boxes are saved.  Otherwise the box derived from the tracker mask
    is saved directly.  Boxes without positive width and height are dropped;
    nothing is written when no box remains.

    Args:
        video_dir: Video path; used to derive the output folder and file name.
        crop_mask: Label mask in which object ``k`` is stored as ``k + 1``.
        input_frame: Image written to disk and used for the XML size.
        crop_input_frame: Image handed to the detector.
        num_objects: Number of tracked objects.
        current_frame_index: Frame index embedded in the output file name.
        pad_box: Box whose first two items are added to the detector
            coordinates as x and y offsets.
    """
    save_root = r'C:\Users\k167\Desktop\dataset\time_plus_bsu_annotation'
    mkdir(save_root)
    save_dir = video_dir.replace('time_plus_bsu','time_plus_bsu_annotation')
    mkdir(save_dir)
    frame_file = os.path.basename(video_dir).replace('.mp4', '_%s.jpg'%(current_frame_index))
    save_img = os.path.join(save_dir, frame_file)
    # Template annotation that change_xml() patches for every saved frame.
    xml_path = r'D:\szj\time_plus-main\testpic\biaoqiang2_tokyo2020_35_mp4_50.xml'
    save_xml_path = save_img.replace('.jpg', '.xml').replace('.png', '.xml')

    def tracked_box(obj_idx):
        # Box of one object taken from the tracker's label mask.
        mask = np.array([crop_mask == obj_idx + 1], dtype=np.uint8)
        return list(get_bbox(mask.squeeze(0)))

    def has_area(box):
        # True when the box has positive height and width.
        return (box[3] - box[1]) > 0 and (box[2] - box[0]) > 0

    bboxes = []
    if num_objects > 1:
        det_results = detector.run(crop_input_frame, detector.opt.vis_thresh)
        # Shift detector boxes by the pad_box origin, in place.
        for det_res in det_results:
            coords = det_res[2]
            coords[0] += pad_box[0]
            coords[1] += pad_box[1]
            coords[2] += pad_box[0]
            coords[3] += pad_box[1]
        det_boxs = [det_res[2] for det_res in det_results]

        # One row per tracked object, one column per detection.
        ious = []
        for obj in range(num_objects):
            bbox = tracked_box(obj)
            ious.append([cal_iou(bbox, det_box) for det_box in det_boxs])

        matches = linear_sum_assignment(ious)
        for matche in matches[1]:
            candidate = det_boxs[matche]
            if has_area(candidate):
                bboxes.append(candidate)
    else:
        for obj in range(num_objects):
            bbox = tracked_box(obj)
            if has_area(bbox):
                bboxes.append(bbox)

    if len(bboxes) > 0:
        cv2.imwrite(save_img, input_frame)
        change_xml(input_frame, bboxes, xml_path, save_xml_path)

sam分割,cv2.selectROI选框

def init_roi(frame, idx):
    """Let the user draw a rectangle on ``frame`` and return it as a box.

    Args:
        frame: Image shown in the selection window (a copy is displayed).
        idx: Label used only in the progress messages.

    Returns:
        ``[x1, y1, x2, y2]`` built from the selected ``(x, y, w, h)``.

    Raises:
        AssertionError: If the selected width or height is not greater than 1.
    """
    window = "select"
    print("init{}: init roi....".format(idx))
    print("init{}: frame shape: {}".format(idx, frame.shape))
    cv2.namedWindow(window, 0)
    selection = cv2.selectROI(window, frame.copy())
    cv2.destroyWindow(window)
    print("init{}: select: {}".format(idx, selection))
    left, top, width, height = selection
    assert width > 1 and height > 1, "please select rectangle"
    bbox = [left, top, left + width, top + height]
    print("init{} bbox: {}".format(idx, bbox))
    return bbox
class SAMNet(object):
    """Box-prompted Segment Anything wrapper that keeps the best-scoring mask."""

    def __init__(self):
        """Load the ``vit_h`` checkpoint and build a CUDA predictor."""
        from segment_anything import sam_model_registry, SamPredictor

        # Alternative checkpoints:
        # model_type, model_path = "vit_b", r"D:\work\hzy\github\segment-anything\sam_vit_b_01ec64.pth"
        # model_type, model_path = "vit_l", r"D:\work\hzy\github\segment-anything\sam_vit_l_0b3195.pth"
        model_type = "vit_h"
        model_path = r"D:\szj\sam_weight\sam_vit_h_4b8939.pth"
        cls_name = self.__class__.__name__
        print("{} weight {}".format(cls_name, model_path))
        sam = sam_model_registry[model_type](checkpoint=model_path)
        self.predictor = SamPredictor(sam.to(device="cuda"))
        print("init {} done...".format(cls_name))
        self.img_size = 384

    def run(self, image, input_boxes):
        """Segment ``image`` using ``input_boxes`` as the box prompt.

        Returns:
            ``(mask, logit, painted_image)``: the mask and logits of the
            highest-scoring prediction, and ``image`` with that mask painted
            on by ``mask_painter``.
        """
        started = time.time()
        self.predictor.set_image(image)
        masks, score, logist = self.predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(input_boxes),
            multimask_output=True,
        )
        # Several candidate masks come back; keep the highest-scoring one.
        best = np.argmax(score)
        mask = masks[best]
        logit = logist[best, :, :]
        print('masks {}, score {}'.format(masks.shape, score.tolist()))
        finished = time.time()
        print("{} time cost {}".format(self.__class__.__name__, finished - started))

        # Overlay settings handed positionally to mask_painter.
        fill_color, fill_alpha = 3, 0.7
        outline_color, outline_width = 1, 5
        painted_image = mask_painter(image, mask.astype('uint8'), fill_color, fill_alpha,
                                     outline_color, outline_width)
        return mask, logit, painted_image
# Interactive initialisation: pick a box, segment it with SAM and preview the
# resulting mask.
# NOTE(review): the `break` below means this fragment was taken from inside a
# loop that is not shown; it does not run at module level as written.
segment = SAMNet()
# NOTE(review): the ROI is selected on `img` but segmentation runs on `frame`
# — presumably the same image; confirm.
bbox = init_roi(img, obj)
mask_np, _, seg_res = segment.run(frame, bbox)
mask_np = mask_np.astype(np.uint8)
cv2.namedWindow("init", 0)
# Scale the mask so its largest value maps to (about) 255 for display;
# max(1, ...) avoids dividing by zero on an empty mask.
cv2.imshow("init", mask_np * int(255 / max(1, mask_np.max())))
key = cv2.waitKey(0)
# 27 is the Esc key code.
if key == 27:
    break
cv2.destroyWindow("init")
df -h

在这里插入图片描述

找开源项目的一些途径
• https://github.com/trending/
• https://github.com/521xueweihan/HelloGitHub
• https://github.com/ruanyf/weekly
• https://www.zhihu.com/column/mm-fe
特殊的查找资源小技巧-常用前缀后缀
• 找百科大全 awesome xxx
• 找例子 xxx sample
• 找空项目架子 xxx starter / xxx boilerplate
• 找教程 xxx tutorial

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值