train 是训练集；val 是训练过程中使用的验证集，目的是让你边训练边查看训练效果，及时判断模型的学习状态；test 是模型训练结束后用于评估模型效果的测试集。只有 train 也可以训练，val 不是必须的，其占比也可以设置得很小。
验证数据集可以理解为从训练数据集中划分出来的一部分。
制作图书馆数据集的代码如下：
### Data Format for Semantic Segmentation
The raw data is processed by generator shell scripts, which produce two subdirectories (`train` and `val`):
```
train or val dir {
image: contains the images for train or val.
label: contains the label PNG files (mode='P') for train or val.
mask: contains the mask PNG files (mode='P') for train or val.
}
```
"""
-*- coding: utf-8 -*-
author: Hao Hu
@date 2022/1/20 11:02 AM
"""
import cv2
import numpy as np
from matplotlib import pyplot as plt
import os.path as osp
import os
from tqdm import tqdm
from PIL import Image
import PIL
from concurrent.futures import ThreadPoolExecutor
def grab_cut(img_path):
    """Estimate the foreground of an image with OpenCV's GrabCut.

    The image is first thresholded channel-wise (values > 50 become 100,
    everything else 0). GrabCut is then run on that thresholded image,
    initialised from a rectangle derived from the image size.

    Args:
        img_path: Path to an image readable by ``cv2.imread``.

    Returns:
        A tuple ``(img, image, mask2, img_ori)``:
        - ``img``: the thresholded image with non-foreground pixels zeroed;
        - ``image``: the thresholded image (channel values 0 or 100);
        - ``mask2``: uint8 mask, 1 for (probable) foreground, 0 otherwise;
        - ``img_ori``: the image exactly as returned by ``cv2.imread``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``img_path``.
    """
    img_ori = cv2.imread(img_path)
    if img_ori is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"cannot read image: {img_path}")

    # Binarise each channel: > 50 -> 100, otherwise 0.
    _, image = cv2.threshold(img_ori, 50, 100, cv2.THRESH_BINARY)

    # Output buffer for grabCut; with GC_INIT_WITH_RECT its initial
    # contents are overwritten from `rect`, so zeros are fine here.
    mask = np.zeros(image.shape[:2], np.uint8)
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)

    ix = img_ori.shape[0] // 22  # derived from the row count (height)
    iy = img_ori.shape[1] // 20  # derived from the column count (width)
    w = iy * 20
    h = ix * 22
    # OpenCV rects are (x, y, width, height).
    # NOTE(review): x comes from the height and y from the width, so the
    # margins look swapped for non-square images. OpenCV clips the rect to
    # the image bounds. Kept as-is; confirm the intended geometry.
    rect = (ix, iy, w, h)
    cv2.grabCut(image, mask, rect, bgd_model, fgd_model, 5,
                cv2.GC_INIT_WITH_RECT)

    # GC_BGD (0) and GC_PR_BGD (2) -> 0; GC_FGD / GC_PR_FGD -> 1.
    mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
    # Force one pixel near the bottom-right to foreground. This presumably
    # guarantees at least one contour for get_mask_box; confirm intent.
    mask2[ix * 21, iy * 19] = 1

    img = image * mask2[:, :, np.newaxis]
    return img, image, mask2, img_ori
def get_mask_box(mask):
    """Find the minimum-area rotated box around the largest blob in ``mask``.

    The largest external contour (by area) is simplified with
    ``cv2.approxPolyDP`` (epsilon=100). ``cv2.minAreaRect`` is then fitted
    to it, and its four corners are returned as integers.

    Args:
        mask: Single-channel uint8 mask accepted by ``cv2.findContours``.

    Returns:
        ``(mask, box)``: the input mask unchanged, and a (4, 2) integer
        array of box corner points.

    Raises:
        ValueError: If ``mask`` contains no contour at all.
    """
    # Two return values: this assumes the OpenCV 4.x findContours signature.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        raise ValueError("mask contains no foreground region")

    largest = max(contours, key=cv2.contourArea)
    approx = cv2.approxPolyDP(largest, epsilon=100, closed=True)
    rotated_rect = cv2.minAreaRect(approx)
    # np.int0 (an alias of np.intp) was removed in NumPy 2.0; astype(np.intp)
    # performs the same truncating conversion.
    box = cv2.boxPoints(rotated_rect).astype(np.intp)
    return mask, box
def imwrite_the_label_img(ori_folder, end_folder_path, img_NAME):
    """Generate and save a palette-mode ('P') PNG label for one image.

    The image is segmented with ``grab_cut``. The rotated bounding box of
    the largest foreground blob is filled in green on a copy of the
    thresholded image, and the result is saved as
    ``<end_folder_path>/<image stem>.png`` in PIL mode 'P'.

    Args:
        ori_folder: Directory containing the source image.
        end_folder_path: Directory to write the label PNG into.
        img_NAME: File name of the source image inside ``ori_folder``.

    Returns:
        The path of the written label file.
    """
    img_path = osp.join(ori_folder, img_NAME)
    _, image, mask, _ = grab_cut(img_path)
    _, box = get_mask_box(mask)

    # Fill the box (thickness -1) in green; colour is BGR (0, 255, 0).
    label_bgr = cv2.drawContours(image.copy(), [box], 0, (0, 255, 0), -1)

    # Keep the image stem and swap the extension: 'a.jpg' -> 'a.png'.
    stem = osp.splitext(img_NAME)[0]
    end_path = osp.join(end_folder_path, stem + '.png')

    # Convert to palette mode in memory and write once. OpenCV arrays are
    # BGR while PIL expects RGB. PIL's save raises on failure, whereas
    # cv2.imwrite only returns False.
    label_rgb = cv2.cvtColor(label_bgr, cv2.COLOR_BGR2RGB)
    Image.fromarray(label_rgb).convert('P').save(end_path)
    return end_path
if __name__ == '__main__':
    ori_folder = '/cloud_disk/users/huh/dataset/lib_dataset/train/image'
    end_folder_path = '/cloud_disk/users/huh/dataset/lib_dataset/train/label'
    img_list = os.listdir(ori_folder)
    # Label writes fail if the output directory does not exist.
    os.makedirs(end_folder_path, exist_ok=True)

    def _process(img_name):
        """Generate the label for one image using the folders above."""
        return imwrite_the_label_img(ori_folder, end_folder_path, img_name)

    # Pass the callable (not the result of calling it) to map() so the
    # work actually runs on the pool. The with-block shuts the pool down.
    with ThreadPoolExecutor(max_workers=100) as executor:  # max thread count
        # Consuming the results waits for every task and re-raises any
        # exception a worker hit, instead of losing it silently.
        for _ in tqdm(executor.map(_process, img_list), total=len(img_list)):
            pass