windows传到Ubuntu服务器
scp -r C:/Users/k167/Desktop/dataset/person_dataset/ ubuntu@192.168.31.35:/ai_projects/tmp/person_dataset/
统计当前文件夹下文件的个数:
ls -l |grep "^-"|wc -l
统计当前文件夹下目录的个数:
ls -l |grep "^d"|wc -l
统计当前文件夹下文件的个数,包括子文件夹里的 :
ls -lR|grep "^-"|wc -l
统计文件夹下目录的个数,包括子文件夹里的:
ls -lR|grep "^d"|wc -l
退出vim
:q
wc
wc val.txt
替换路径
:%s/C:\\Users\\k167\\Desktop\\dataset/\\ai_projects\\tmp/g
nvidia
nvitop
xml2yolo
输入数据是img_list.txt里面包含所有的图片数据,bsuval用于验证,tokyo用于训练,去掉攀岩,用kmp算法去匹配
from tqdm import tqdm
import xml.etree.ElementTree as ET
import os
classes = ["person"]
class Solution:
    """Substring search implemented with the Knuth-Morris-Pratt algorithm."""

    def get_next(self, T):
        """Build the optimised failure table for pattern ``T``.

        Entry ``k`` is the pattern index to resume from after a mismatch at
        ``k``; ``-1`` means "advance in the text and restart the pattern".
        Where the fallback character equals the mismatching one, the
        fallback's own entry is reused so the same comparison is not repeated.

        Returns:
            A list of ``len(T)`` integers (empty for an empty pattern).
        """
        size = len(T)
        table = [-1] * size
        pos, fallback = 0, -1
        while pos < size - 1:
            if fallback != -1 and T[pos] != T[fallback]:
                # Mismatch inside the current border: shrink the border.
                fallback = table[fallback]
                continue
            pos += 1
            fallback += 1
            if pos < size and T[pos] != T[fallback]:
                table[pos] = fallback
            else:
                # Same character at the fallback position: skip straight past it.
                table[pos] = table[fallback]
        return table

    def kmp(self, S, T):
        """Return the index of the first occurrence of ``T`` in ``S``, or -1.

        An empty pattern matches at index 0.
        """
        table = self.get_next(T)
        text_len, pat_len = len(S), len(T)
        s_pos = t_pos = 0
        while s_pos < text_len and t_pos < pat_len:
            if t_pos != -1 and S[s_pos] != T[t_pos]:
                # Keep the text position and realign the pattern.
                t_pos = table[t_pos]
            else:
                s_pos += 1
                t_pos += 1
        return s_pos - t_pos if t_pos == pat_len else -1
def convert(size, box):
    """Convert a pixel box to a normalised centre/size box.

    Args:
        size: ``(width, height)`` of the image.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels (note the x, x, y, y order).

    Returns:
        ``(x_center, y_center, width, height)``, each divided by the matching
        image dimension.
    """
    inv_w, inv_h = 1.0 / size[0], 1.0 / size[1]
    x_min, x_max, y_min, y_max = box[0], box[1], box[2], box[3]
    return (
        (x_min + x_max) / 2.0 * inv_w,
        (y_min + y_max) / 2.0 * inv_h,
        (x_max - x_min) * inv_w,
        (y_max - y_min) * inv_h,
    )
def convert_annotation(img_path):
    """Write a YOLO label file next to ``img_path`` from its Pascal VOC XML.

    The XML is expected beside the image with the same stem
    (``foo.jpg`` -> ``foo.xml``); the labels are written to ``foo.txt``.
    Objects whose class name is not in the module-level ``classes`` list are
    reported and skipped.

    Args:
        img_path: Path to a ``.jpg`` or ``.png`` image.

    Raises:
        AssertionError: If ``img_path`` does not end in ``.jpg`` or ``.png``.
    """
    # Split off the real extension so a "jpg"/"png" substring elsewhere in the
    # path (e.g. a folder name) is never rewritten.
    stem, ext = os.path.splitext(img_path)
    # Validate before touching the filesystem so a bad path leaves no stray .txt.
    assert ext in ('.jpg', '.png')
    xml_path = stem + '.xml'
    with open(xml_path) as f:
        root = ET.fromstring(f.read())
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    label_lines = []
    for obj in root.iter('object'):
        cls = obj.find('name').text
        if cls not in classes:
            print('不存在', cls)
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        # convert() expects (xmin, xmax, ymin, ymax).
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w, h), b)
        label_lines.append(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

    # The label file is written (possibly empty) only after the XML parsed
    # successfully, and is always closed.
    with open(stem + '.txt', 'w') as out_file:
        out_file.writelines(label_lines)
def get_txt(path):
    """Read image paths from an ``img_list.txt`` file.

    Only the first space-separated field of each line is used. Entries whose
    first path component is ``bg_images`` are dropped. Every kept entry is
    prefixed with the part of ``path`` that precedes ``img_list.txt``.

    Blank lines are skipped; previously the first blank line ended the read
    and everything after it was silently lost.

    Args:
        path: Path to the list file.

    Returns:
        List of prefixed image paths, in file order.
    """
    path_head = path.split('img_list.txt')[0]
    txt = []
    with open(path) as f:
        for raw_line in f:
            line = raw_line.strip().split(' ')[0]
            if not line:
                continue
            if line.split('/')[0] != 'bg_images':
                txt.append(path_head + line)
    return txt
def main():
    """Build classes/train/val/crop lists and YOLO labels for the person dataset.

    Reads ``img_list.txt``, writes ``classes.txt``, then for every listed image
    whose path does not contain ``panyan``:

    * paths containing ``tokyo`` go to ``train.txt``;
    * paths containing ``bsuval`` go to ``val.txt``;
    * paths containing ``bsu`` go to ``crop.txt`` (this also matches ``bsuval``);
    * the VOC XML next to the image is converted to a YOLO ``.txt``.

    NOTE(review): the source had lost its indentation; the three list checks
    and the conversion are assumed to all sit inside the ``panyan`` filter.
    Confirm against the original script.
    """
    # Location of the image list; all outputs are written beside it.
    txt_path = r'C:/Users/k167/Desktop/dataset/annotation_images/img_list.txt'
    root = os.path.dirname(txt_path)

    # Write classes.txt, one class name per line.
    with open(os.path.join(root, 'classes.txt'), 'w') as f:
        f.writelines(f"{category}\n" for category in classes)

    txt = get_txt(txt_path)

    # One matcher is enough: Solution keeps no per-search state.
    matcher = Solution()
    # Context managers guarantee the three list files are flushed and closed.
    with open(os.path.join(root, 'train.txt'), 'w') as train, \
            open(os.path.join(root, 'val.txt'), 'w') as val, \
            open(os.path.join(root, 'crop.txt'), 'w') as crop:
        for path in tqdm(txt):
            if matcher.kmp(path, 'panyan') != -1:
                continue
            if matcher.kmp(path, 'tokyo') > -1:
                train.write(path + '\n')
            if matcher.kmp(path, 'bsuval') > -1:
                val.write(path + '\n')
            if matcher.kmp(path, 'bsu') > -1:
                crop.write(path + '\n')
            convert_annotation(path)


if __name__ == '__main__':
    main()
读txt文件
def get_txt(path):
    """Read image paths from a list file and print how many were kept.

    Only the first space-separated field of each line is used, and entries
    whose first path component is ``bg_images`` are dropped.

    Blank lines are skipped; previously the first blank line ended the read
    and everything after it was silently lost.

    Args:
        path: Path to the list file.

    Returns:
        List of kept entries, in file order.
    """
    txt = []
    with open(path) as f:
        for raw_line in f:
            line = raw_line.strip().split(' ')[0]
            if not line:
                continue
            if line.split('/')[0] != 'bg_images':
                txt.append(line)
    print(len(txt))
    return txt
读xml voc格式数据
def read_annotations(xml_path):
    """Read the object boxes from a Pascal VOC XML file.

    Args:
        xml_path: Path to the annotation file.

    Returns:
        A list of ``[class_name, width, height, x1, y1, x2, y2]`` entries with
        box coordinates rounded to integers. Boxes that are degenerate or lie
        outside the image are skipped. An empty list is returned when the
        image size is not positive or when any object has an empty ``<name>``.
    """
    # xml.etree.cElementTree was removed in Python 3.9; ElementTree uses the
    # C accelerator automatically when it is available.
    import xml.etree.ElementTree as ET

    element = ET.parse(xml_path).getroot()
    size = element.find('size')
    element_width = int(size.find('width').text)
    element_height = int(size.find('height').text)
    if element_width <= 0 or element_height <= 0:
        return []

    results = []
    for element_obj in element.findall('object'):
        # Check for None before str(): str(None) is 'None', which made the
        # original guard unreachable.
        name_text = element_obj.find('name').text
        if name_text is None:
            return []
        class_name = str(name_text)
        obj_bbox = element_obj.find('bndbox')
        x1 = int(round(float(obj_bbox.find('xmin').text)))
        y1 = int(round(float(obj_bbox.find('ymin').text)))
        x2 = int(round(float(obj_bbox.find('xmax').text)))
        y2 = int(round(float(obj_bbox.find('ymax').text)))
        if x1 < 0 or y1 < 0 or x2 > element_width or y2 > element_height or x1 >= x2 or y1 >= y2:
            continue
        results.append([class_name, element_width, element_height, x1, y1, x2, y2])
    return results
创建文件夹
def mkdir(path):
    """Create ``path`` and any missing parents; do nothing if it already exists."""
    if os.path.exists(path):
        return
    os.makedirs(path)
修改voc数据 有多个目标,
传入的box格式如下
[['person', 1.0, [1, 600, 385, 1339], [], [], [], []], ['person', 0.9092891812324524, [469, 648, 615, 981], [], [], [], []], ['person', 0.8290438055992126, [675, 660, 726, 783], [], [], [], []], ['person', 0.788838267326355, [641, 662, 680, 778], [], [], [], []], ['person', 0.6277031302452087, [38, 586, 179, 1048], [], [], [], []], ['person', 0.6049114465713501, [412, 678, 437, 745], [], [], [], []]]
import xml.etree.ElementTree as ET
def create_object(root, xi, yi, xa, ya, obj_name):
    """Append a Pascal VOC ``<object>`` element to ``root``.

    Args:
        root: Parent element that receives the new ``<object>``.
        xi: Value written to ``<xmin>``.
        yi: Value written to ``<ymin>``.
        xa: Value written to ``<xmax>``.
        ya: Value written to ``<ymax>``.
        obj_name: Class name written to ``<name>``.
    """
    obj_node = ET.SubElement(root, 'object')

    # Descriptive children; pose/truncated/difficult always get fixed values.
    fixed_fields = (
        ('name', str(obj_name)),
        ('pose', 'Unspecified'),
        ('truncated', '0'),
        ('difficult', '0'),
    )
    for tag, text in fixed_fields:
        ET.SubElement(obj_node, tag).text = text

    # Bounding box, in xmin, ymin, xmax, ymax order.
    box_node = ET.SubElement(obj_node, 'bndbox')
    for tag, value in (('xmin', xi), ('ymin', yi), ('xmax', xa), ('ymax', ya)):
        ET.SubElement(box_node, tag).text = '%s' % value
def change_xml(crop, boxs, xml_path, save_xml_path):
    """Write a VOC XML for ``crop`` by editing a template annotation.

    The template's ``<size>`` is replaced with the size of ``crop``, its first
    ``<object>`` receives the first box, and one extra ``<object>`` named
    ``person`` is appended for every further box. Nothing is written when
    ``boxs`` is empty.

    Args:
        crop: Image array; ``crop.shape`` starts with ``(height, width)``.
        boxs: Detection records; index 2 of each record is
            ``[xmin, ymin, xmax, ymax]``.
        xml_path: Template XML; must contain ``<size>`` and at least one
            ``<object>`` with a ``<bndbox>``.
        save_xml_path: Output path; a ``.png``/``.jpg`` in it is replaced by
            ``.xml``.
    """
    if len(boxs) == 0:
        return
    updateTree = ET.parse(xml_path)  # template to modify
    root = updateTree.getroot()

    # Image arrays are indexed (row, column): shape[0] is the height and
    # shape[1] the width. The original wrote them the other way round.
    img_h, img_w = crop.shape[0], crop.shape[1]
    size = root.find('size')
    size.find('width').text = str(img_w)
    size.find('height').text = str(img_h)

    # Reuse the template's first object for the first box.
    first_box = root.find('object').find('bndbox')
    first_coords = boxs[0][2]
    first_box.find('xmin').text = str(first_coords[0])
    first_box.find('ymin').text = str(first_coords[1])
    first_box.find('xmax').text = str(first_coords[2])
    first_box.find('ymax').text = str(first_coords[3])

    # Remaining boxes become new <object> elements.
    for box in boxs[1:]:
        create_object(root, box[2][0], box[2][1], box[2][2], box[2][3], 'person')

    updateTree.write(save_xml_path.replace(".png", ".xml").replace(".jpg", ".xml"))
裁剪图片,以bbox为中心,向外pad,进行裁剪
img(2160, 3840, 3)
bbox[730, 824, 1122, 1274]
def expand_img(img, bbox, img_h, img_w, pad, square=True):
    """Crop ``img`` around ``bbox`` after growing the box by ``pad``.

    Args:
        img: Image array indexed ``[row, column, channel]``.
        bbox: ``[x1, y1, x2, y2]`` box to centre the crop on.
        img_h: Image height used for clamping.
        img_w: Image width used for clamping.
        pad: Padding in pixels; a single value, or a ``(pad_w, pad_h)`` pair.
        square: When true, the box is first made square using its longer side.

    Returns:
        ``(bbox, crop_img)``: the clamped ``[x1, y1, x2, y2]`` list and the
        slice ``img[y1:y2, x1:x2, :3]``.
    """
    if isinstance(pad, (tuple, list)):
        pad_w, pad_h = int(pad[0]), int(pad[1])
    else:
        pad_w = pad_h = pad

    left, top, right, bottom = (int(v) for v in bbox)
    box_h = bottom - top
    box_w = right - left
    if square:
        box_h = box_w = max(box_h, box_w)

    center_x = (left + right) // 2
    center_y = (top + bottom) // 2
    half_w, half_h = box_w // 2, box_h // 2

    def clamp(value, limit):
        # Keep a coordinate inside [0, limit - 1].
        return max(0, min(limit - 1, value))

    bbox = [
        clamp(center_x - half_w - pad_w, img_w),
        clamp(center_y - half_h - pad_h, img_h),
        clamp(center_x + half_w + pad_w, img_w),
        clamp(center_y + half_h + pad_h, img_h),
    ]
    x1, y1, x2, y2 = bbox
    return bbox, img[y1:y2, x1:x2, :3]
找冰壶视频,抽帧,筛选
生成voc格式,yolo格式
voc用labelImg生成xml文件,然后将xml格式转为yolo格式
文件夹和文件名不能有空格,否则会报错;可用 os.renames(dir_name, dir_name.replace(' ', '_')) 改名
2017_Best_Curling_Shots
2018_season_of_champions_shots
2021_AGI_Top_Shots_of_the_Year
2022_Tim_Horton’s_Brier_Top_Ten_shots
2023_Tim_Horton’s_Brier_Top_Ten_Shots为训练集
JAPAN_v_UNITED_STATES_Mixed_Doubles_Curling_Championship_2023为验证集
使用yolov8l.pt预训练模型,设置300epoch
from tqdm import tqdm
import xml.etree.ElementTree as ET
import os
classes = ["yellow", "red"]
def convert(size, box):
    """Normalise a pixel box into centre/size form.

    Args:
        size: Image ``(width, height)``.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels.

    Returns:
        ``(x_center, y_center, width, height)`` scaled by the reciprocal of
        the image width (x values) and height (y values).
    """
    scale_x = 1.0 / size[0]
    scale_y = 1.0 / size[1]
    mid_x = (box[0] + box[1]) / 2.0
    mid_y = (box[2] + box[3]) / 2.0
    span_x = box[1] - box[0]
    span_y = box[3] - box[2]
    return (mid_x * scale_x, mid_y * scale_y, span_x * scale_x, span_y * scale_y)
def convert_annotation(img_path):
    """Write a YOLO label file next to ``img_path`` from its Pascal VOC XML.

    The XML is read from the same stem (``foo.jpg`` -> ``foo.xml``) using the
    ``gb18030`` encoding, and the labels are written to ``foo.txt``. Objects
    whose class name is not in the module-level ``classes`` list are reported
    and skipped.

    Args:
        img_path: Path to a ``.jpg`` or ``.png`` image.

    Raises:
        AssertionError: If ``img_path`` does not end in ``.jpg`` or ``.png``.
    """
    # Work on the real extension only: a plain str.replace('jpg', ...) would
    # also rewrite "jpg"/"png" appearing in folder or file names.
    stem, ext = os.path.splitext(img_path)
    # Validate before creating any file.
    assert ext in ('.jpg', '.png')
    xml_path = stem + '.xml'
    print(xml_path)
    with open(xml_path, encoding='gb18030') as f:
        root = ET.fromstring(f.read())
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    label_lines = []
    for obj in root.iter('object'):
        cls = obj.find('name').text
        if cls not in classes:
            print('不存在', cls)
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        # convert() expects (xmin, xmax, ymin, ymax).
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w, h), b)
        label_lines.append(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

    # Written only after the XML parsed successfully; always closed.
    with open(stem + '.txt', 'w') as out_file:
        out_file.writelines(label_lines)
# exit()
def get_txt(path):
    """Return the non-blank lines of ``path`` with surrounding whitespace removed.

    Blank lines are skipped; previously the first blank line ended the read
    and everything after it was silently lost.

    Args:
        path: Path to a text file with one entry per line.

    Returns:
        List of stripped lines, in file order.
    """
    txt = []
    with open(path) as f:
        for raw_line in f:
            line = raw_line.strip()
            if line:
                txt.append(line)
    return txt
def main(img_dir=r'C:\Users\k167\Desktop\dataset\curling'):
    """Prepare the curling dataset: file lists, class list and YOLO labels.

    Steps:
        1. Walk ``img_dir``, renaming folders and files so that no name
           contains a space.
        2. Write every ``.jpg`` path to ``img_list.txt``; images whose parent
           folder is the validation video go to ``val.txt``, the rest to
           ``train.txt``.
        3. Write ``classes.txt`` from the module-level ``classes`` list.
        4. Convert the VOC XML of every listed image to a YOLO ``.txt``.

    NOTE(review): the source had lost its indentation; the train/val split is
    assumed to apply to ``.jpg`` files only. Confirm against the original.

    Args:
        img_dir: Dataset root. Defaults to the path the script was written for.
    """
    val_dir_name = 'JAPAN_v_UNITED_STATES_Mixed_Doubles_Curling_Championship_2023'
    txt_path = os.path.join(img_dir, 'img_list.txt')

    # The list files are closed on leaving this block, so img_list.txt is
    # fully flushed before it is read back below.
    with open(txt_path, 'w') as img_list, \
            open(os.path.join(img_dir, 'train.txt'), 'w') as train, \
            open(os.path.join(img_dir, 'val.txt'), 'w') as val:
        for path, dir_lst, file_lst in os.walk(img_dir):
            for idx, dir_name in enumerate(dir_lst):
                if ' ' not in dir_name:
                    continue
                new_name = dir_name.replace(' ', '_')
                # Rename relative to the directory being walked, not the CWD.
                os.rename(os.path.join(path, dir_name), os.path.join(path, new_name))
                # Update in place so os.walk descends into the renamed folder.
                dir_lst[idx] = new_name
            for file_name in file_lst:
                if ' ' in file_name:
                    new_name = file_name.replace(' ', '_')
                    # Full paths make os.chdir unnecessary.
                    os.rename(os.path.join(path, file_name), os.path.join(path, new_name))
                    file_name = new_name
                img_path = os.path.join(path, file_name)
                if img_path.split('.')[-1] != 'jpg':
                    continue
                img_list.write(img_path + '\n')
                # basename(path) is the image's parent folder on any OS.
                if os.path.basename(path) == val_dir_name:
                    val.write(img_path + '\n')
                else:
                    train.write(img_path + '\n')

    # Write classes.txt next to the image list, one class name per line.
    root = os.path.dirname(txt_path)
    with open(os.path.join(root, 'classes.txt'), 'w') as f:
        f.writelines(f"{category}\n" for category in classes)

    # Convert every listed image's annotation.
    txt = get_txt(txt_path)
    for path in tqdm(txt):
        convert_annotation(path)


if __name__ == '__main__':
    main()
当你执行 `mask_np[mask_np == 1] = obj + 1` 这行代码时,它会让所有等于1的元素被替换为 `obj + 1`。
让我们来解释这个代码的步骤:
- `mask_np == 1` 会返回一个与 `mask_np` 具有相同形状的布尔数组,其中元素为True表示对应位置的元素等于1。
- 然后,将这个布尔数组应用于 `mask_np`,即只选取与 `mask_np == 1` 对应位置为True的元素。
- 最后,使用 `=` 运算符将选中的元素赋值为 `obj + 1`。
换句话说,这行代码会按照条件选取 `mask_np` 中值为1的元素,并将它们替换为 `obj + 1` 的值。这可以用来在数组中进行条件替换或更新特定的元素值。
多目标跟踪生成的mask,mask背景为0,目标为1,2。
crop_frame[:, :, ::-1] 使用[:, :, ::-1]对数组进行切片,相当于反转第三维的顺序,即将BGR顺序转换为RGB顺序
# Fragment (indentation lost in the notes): one segmentation pass per object index.
for obj in range(num_objects):
# crop_frame[:, :, ::-1] reverses the last (channel) axis before the frame is passed to segment.run.
mask_np, _, seg_res = segment.run(crop_frame[:, :, ::-1], crop_box)
mask_np = mask_np.astype(np.uint8)
# Relabel this object's pixels from 1 to obj + 1, so 0 remains the background value.
mask_np[mask_np == 1] = obj + 1
mask生成bbox
def get_pts(w):
    """Return the first and last index at which ``w`` is positive.

    Args:
        w: 1-D array, e.g. a mask summed along one axis.

    Returns:
        ``(first, last)`` indices, or ``(0, 0)`` unless more than two entries
        are positive.
    """
    positive_idx = np.arange(len(w))[w > 0]
    if len(positive_idx) <= 2:
        return 0, 0
    return positive_idx[0], positive_idx[-1]
def get_bbox(thresh):
    """Compute the bounding box of the positive region of a 2-D mask.

    Args:
        thresh: 2-D mask; axis 0 is rows (y), axis 1 is columns (x).

    Returns:
        ``(x1, y1, x2, y2)`` taken from ``get_pts`` on the column sums and on
        the row sums.
    """
    # Summing over rows leaves one total per column, and vice versa.
    column_extent = get_pts(np.sum(thresh, axis=0))
    row_extent = get_pts(np.sum(thresh, axis=1))
    return column_extent[0], row_extent[0], column_extent[1], row_extent[1]
# Fragment (indentation lost in the notes): collect one box per object label in crop_mask.
bboxes = []
for obj in range(num_objects):
# Wrapping the comparison in a list adds a leading axis of length 1; squeeze(0) below removes it again.
mask = np.array([crop_mask == obj+1], dtype=np.uint8)
bbox = list(get_bbox(mask.squeeze(0))) # example mask shape before squeeze: [1, 640, 640]
bboxes.append(bbox)
获得文件夹下的图片list
from tqdm import tqdm
import xml.etree.ElementTree as ET
import os
def main(img_dir=r'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'):
    """Write ``img_list.txt`` listing every ``.jpg`` found under ``img_dir``.

    While walking the tree, folders and files whose names contain spaces are
    renamed with underscores, and the listed paths use the new names.

    Args:
        img_dir: Root folder to scan; ``img_list.txt`` is created inside it.
            Defaults to the placeholder path the script was written with.
    """
    # The context manager guarantees the list is flushed and closed.
    with open(os.path.join(img_dir, 'img_list.txt'), 'w') as f:
        for path, dir_lst, file_lst in os.walk(img_dir):
            for idx, dir_name in enumerate(dir_lst):
                if ' ' not in dir_name:
                    continue
                new_name = dir_name.replace(' ', '_')
                # Rename relative to the directory being walked, not the CWD.
                os.rename(os.path.join(path, dir_name), os.path.join(path, new_name))
                # Update in place so os.walk descends into the renamed folder.
                dir_lst[idx] = new_name
            for file_name in file_lst:
                if ' ' in file_name:
                    new_name = file_name.replace(' ', '_')
                    # Full paths make os.chdir unnecessary.
                    os.rename(os.path.join(path, file_name), os.path.join(path, new_name))
                    file_name = new_name
                img_path = os.path.join(path, file_name)
                if img_path.split('.')[-1] == 'jpg':  # only jpg images are listed
                    f.write(img_path + '\n')


if __name__ == '__main__':
    main()
抽帧
import cv2
import os
video_path = r'C:\Users\k167\Desktop\dataset\10m_diving_sync/2012_Diving_Women_Sync_10m.mp4'  # video to sample
output_path = r'C:\Users\k167\Desktop\dataset\10m_diving_sync/2012_Diving_Women_Sync_10m/'  # output folder
interval = 15  # save one frame out of every 15
os.makedirs(output_path, exist_ok=True)


def extract_frames(video_path, output_path, interval):
    """Save every ``interval``-th frame of a video as a JPEG.

    Files are named ``<video stem>_<frame number>.jpg``, with frame numbers
    starting at 1, and each saved path is printed.

    Args:
        video_path: Path of the video to read.
        output_path: Prefix for the output files; end it with a path separator
            to write into a folder.
        interval: Spacing, in frames, between saved images.
    """
    # Take the stem from the file name itself rather than from a fixed '/' position.
    name = os.path.splitext(os.path.basename(video_path))[0]
    vid = cv2.VideoCapture(video_path)
    num = 1
    try:
        while vid.isOpened():
            is_read, frame = vid.read()
            if not is_read:
                break
            if num % interval == 0:
                save_path = output_path + '%s_%d' % (name, num) + '.jpg'
                print(save_path)
                cv2.imwrite(save_path, frame)
            num += 1
    finally:
        # Release the capture handle even if reading or writing fails.
        vid.release()


if __name__ == '__main__':
    extract_frames(video_path, output_path, interval)
要求使用cutie生成裁剪的图片和xml,但是需要设置两个pad参数,一个为跟踪时pad的大小,一个为裁剪的大小。
pad_box 是pad为200的裁剪框,crop_pad_box 是pad为600的裁剪框,要保存600的裁剪框
创建一个与原始图像尺寸相同的空白掩码数组。
根据给定的裁剪框,在空白掩码数组中填充原始掩码数据,形成一个带有裁剪框的掩码图像。
根据另一个给定的裁剪区域,从带有裁剪框的掩码图像中截取出真正需要的裁剪后的掩码数据。
# Fragment from the tracking loop (indentation and surrounding definitions not shown).
# First crop uses `pad`; pad_box is later unpacked as (x1, y1, x2, y2).
crop_input_frame, crop_frame_shape, pad_box = crop_and_pad_img_no_resize(frame, last_bbox,img_shape, input_shape,pad)
# Resize the crop to the tracker's input size.
input_frame = cv2.resize(crop_input_frame, (tracker_inp_w, tracker_inp_h))
t22 = time.time()
frame_torch = image_to_torch(input_frame, device=device)
crop_prob = processor.step(frame_torch)
crop_mask = torch_prob_to_numpy_mask(crop_prob)
# Split the label mask into one binary mask per object label (1 and 2).
crop_mask_obj1 = (crop_mask == 1).astype(np.uint8)
crop_mask_obj2 = (crop_mask == 2).astype(np.uint8)
# Resize the masks back to the crop's size; dsize is given as (width, height).
# NOTE(review): cv2.resize's third positional parameter is `dst`, so cv2.INTER_LINEAR here is
# not received as the interpolation flag. Confirm whether `interpolation=` was intended.
crop_mask_obj1 = cv2.resize(crop_mask_obj1, (crop_frame_shape[1], crop_frame_shape[0]), cv2.INTER_LINEAR)
crop_mask_obj2 = cv2.resize(crop_mask_obj2, (crop_frame_shape[1], crop_frame_shape[0]), cv2.INTER_LINEAR)
crop_mask = cv2.resize(crop_mask, (crop_frame_shape[1], crop_frame_shape[0]), cv2.INTER_LINEAR)
# Rebuild a single label mask at crop size: 0 background, 1 and 2 for the two objects.
mask_c = np.zeros([crop_frame_shape[0], crop_frame_shape[1]], dtype=np.uint8)
# print(crop_mask_obj1.shape)
# print(crop_frame_shape[0], crop_frame_shape[1])
mask_c[crop_mask_obj1 == 1] = 1
mask_c[crop_mask_obj2 == 1] = 2
# Second crop of the same frame, this time with the pad used for the saved image (crop_pad).
crop_crop_input_frame, crop_crop_frame_shape, crop_pad_box = crop_and_pad_img_no_resize(frame, last_bbox, img_shape, input_shape,crop_pad)
# Paste the crop-sized label mask into an all-zero mask of the full image size at pad_box.
mask_cr = np.zeros([img_h, img_w], dtype=np.uint8)
pad_x1, pad_y1, pad_x2, pad_y2 = pad_box
mask_cr[pad_y1: pad_y2, pad_x1: pad_x2] = mask_c
# Cut the region of the second crop box out of the full-size mask.
pad_x1, pad_y1, pad_x2, pad_y2 = crop_pad_box
pad_mask = mask_cr[pad_y1: pad_y2, pad_x1: pad_x2]
# if((current_frame_index - 1) % 10 == 0):
# NOTE(review): five arguments are passed here; confirm they match the signature of the crop() in use.
crop(video_dir, pad_mask, crop_crop_input_frame, num_objects, current_frame_index)
每个视频用cutie生成img和xml,以视频名为文件夹,每个项目20个
跟踪预测框 检测框 做iou,保留最大的那个
多人做匈牙利匹配,匹配的依据是计算目标边界框与检测结果边界框之间的IoU值,并使用匈牙利匹配算法找到最佳的匹配结果。
def crop(video_dir, crop_mask, input_frame, crop_input_frame, num_objects, current_frame_index, pad_box):
    """Save a frame and its VOC XML using boxes derived from the tracking mask.

    With a single object, the box comes straight from the mask. With several
    objects, a detector is run on ``crop_input_frame``, its boxes are shifted
    by ``pad_box`` and assigned to the mask boxes with
    ``linear_sum_assignment``; the assigned detector boxes are the ones saved.
    Nothing is written when no box has positive width and height.

    Args:
        video_dir: Video path; the output folder and file name derive from it.
        crop_mask: Label mask in which object ``k`` has value ``k + 1``.
        input_frame: Image that is saved and whose size goes into the XML.
        crop_input_frame: Image handed to the detector.
        num_objects: Number of tracked objects.
        current_frame_index: Frame number used in the output file name.
        pad_box: Offset added to detector boxes; index 0 is x, index 1 is y.
    """
    save_root = r'C:\Users\k167\Desktop\dataset\time_plus_bsu_annotation'
    mkdir(save_root)
    save_dir = video_dir.replace('time_plus_bsu', 'time_plus_bsu_annotation')
    mkdir(save_dir)
    save_img = os.path.join(save_dir, os.path.basename(video_dir).replace('.mp4', '_%s.jpg' % (current_frame_index)))
    # Template annotation that change_xml edits.
    xml_path = r'D:\szj\time_plus-main\testpic\biaoqiang2_tokyo2020_35_mp4_50.xml'
    save_xml_path = save_img.replace('.jpg', '.xml').replace('.png', '.xml')

    # One box per object label, shared by both branches below.
    mask_boxes = []
    for obj in range(num_objects):
        mask = np.array([crop_mask == obj + 1], dtype=np.uint8)
        mask_boxes.append(list(get_bbox(mask.squeeze(0))))

    if num_objects > 1:
        det_results = detector.run(crop_input_frame, detector.opt.vis_thresh)
        # Each detection looks like ['person', score, [x1, y1, x2, y2], ...];
        # shift its box by the pad_box offset.
        for det_res in det_results:
            det_res[2][0] += pad_box[0]
            det_res[2][1] += pad_box[1]
            det_res[2][2] += pad_box[0]
            det_res[2][3] += pad_box[1]
        det_boxs = [det_res[2] for det_res in det_results]
        # Cost matrix: rows are mask boxes, columns are detector boxes.
        # NOTE(review): linear_sum_assignment minimises cost, so this presumes
        # cal_iou returns negated IoU. Confirm.
        ious = [[cal_iou(bbox, det_box) for det_box in det_boxs] for bbox in mask_boxes]
        matches = linear_sum_assignment(ious)
        candidates = [det_boxs[matche] for matche in matches[1]]
    else:
        candidates = mask_boxes

    # Keep only boxes with positive width and height.
    bboxes = [b for b in candidates if (b[3] - b[1]) > 0 and (b[2] - b[0]) > 0]

    if len(bboxes) > 0:
        cv2.imwrite(save_img, input_frame)
        # change_xml reads coordinates from index 2 of each record, so wrap the
        # flat [x1, y1, x2, y2] boxes; passing them bare raised TypeError.
        # The 1.0 score is a placeholder that change_xml does not read.
        boxs = [['person', 1.0, box] for box in bboxes]
        change_xml(input_frame, boxs, xml_path, save_xml_path)
sam分割,cv2.selectROI选框
def init_roi(frame, idx):
    """Let the user draw a rectangle on ``frame`` and return it as a box.

    Args:
        frame: Image shown in the selection window.
        idx: Label used only in the log messages.

    Returns:
        ``[x1, y1, x2, y2]`` of the selected rectangle.

    Raises:
        AssertionError: If the selected width or height is not above 1.
    """
    print("init{}: init roi....".format(idx))
    print("init{}: frame shape: {}".format(idx, frame.shape))

    window = "select"
    cv2.namedWindow(window, 0)
    # Select on a copy of the frame.
    res = cv2.selectROI(window, frame.copy())
    cv2.destroyWindow(window)
    print("init{}: select: {}".format(idx, res))

    left, top, width, height = res
    assert width > 1 and height > 1, "please select rectangle"
    bbox = [left, top, left + width, top + height]
    print("init{} bbox: {}".format(idx, bbox))
    return bbox
class SAMNet(object):
    """Box-prompted Segment Anything predictor that also paints the mask."""

    def __init__(self):
        """Load the ``vit_h`` checkpoint and build a predictor on the ``cuda`` device."""
        from segment_anything import sam_model_registry, SamPredictor
        # Other checkpoints tried earlier:
        #   "vit_b", r"D:\work\hzy\github\segment-anything\sam_vit_b_01ec64.pth"
        #   "vit_l", r"D:\work\hzy\github\segment-anything\sam_vit_l_0b3195.pth"
        model_type = "vit_h"
        model_path = r"D:\szj\sam_weight\sam_vit_h_4b8939.pth"
        cls_name = self.__class__.__name__
        print("{} weight {}".format(cls_name, model_path))
        sam = sam_model_registry[model_type](checkpoint=model_path)
        self.predictor = SamPredictor(sam.to(device="cuda"))
        print("init {} done...".format(cls_name))
        self.img_size = 384

    def run(self, image, input_boxes):
        """Segment the object inside ``input_boxes`` and paint it on ``image``.

        Several candidate masks are requested and the one with the highest
        score is kept.

        Args:
            image: Image passed to the predictor and to the painter.
            input_boxes: Box prompt; converted with ``np.array``.

        Returns:
            ``(mask, logit, painted_image)`` for the best-scoring candidate.
        """
        started = time.time()
        self.predictor.set_image(image)
        masks, score, logist = self.predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(input_boxes),
            multimask_output=True,
        )
        best = np.argmax(score)
        mask = masks[best]
        logit = logist[best, :, :]
        print('masks {}, score {}'.format(masks.shape, score.tolist()))
        finished = time.time()
        print("{} time cost {}".format(self.__class__.__name__, finished - started))

        # Overlay settings handed to mask_painter.
        mask_color, mask_alpha = 3, 0.7
        contour_color, contour_width = 1, 5
        painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha,
                                     contour_color, contour_width)
        return mask, logit, painted_image
# Fragment (indentation and enclosing loop not shown): segment a user-selected box and preview the mask.
segment = SAMNet()
# Ask the user to draw the box for object index `obj` on `img`.
bbox = init_roi(img, obj)
mask_np, _, seg_res = segment.run(frame, bbox)
mask_np = mask_np.astype(np.uint8)
cv2.namedWindow("init", 0)
# Scale the mask so its largest value maps to about 255 for display; max(1, ...) avoids dividing by zero on an empty mask.
cv2.imshow("init", mask_np * int(255 / max(1, mask_np.max())))
# Wait for a key press; 27 is the Esc key code.
key = cv2.waitKey(0)
# NOTE(review): `break` implies this snippet sits inside a loop that is not shown here.
if key == 27:
break
cv2.destroyWindow("init")
df -h
找开源项目的一些途径
• https://github.com/trending/
• https://github.com/521xueweihan/HelloGitHub
• https://github.com/ruanyf/weekly
• https://www.zhihu.com/column/mm-fe
特殊的查找资源小技巧-常用前缀后缀
• 找百科大全 awesome xxx
• 找例子 xxx sample
• 找空项目架子 xxx starter / xxx boilerplate
• 找教程 xxx tutorial