1. 为训练 BING 检测器，需要把 AFLW 数据集中的人脸框转换成下面脚本所示格式（YAML）的正样本标注文件。
AFLW 的标注原本存放在 SQLite3 数据库中，之前已从中导出为 .txt 文件（face_rect.txt）。
下面的脚本利用该 .txt 文件生成标注文件：
"""Generate per-image YAML annotation files for BING from face_rect.txt.

Each data line of face_rect.txt is read as whitespace-separated fields
``<id> <image path> <x> <y> <width> <height>``; the first line is a header and
is skipped.  For every image, one ``Annotations/<image path without
extension>.yml`` file is written containing one ``bndbox`` entry per face rect
listed for that image.  Images are read from ``./JPEGImages/`` only to obtain
their width/height for the ``size`` entry and for clamping the boxes.
"""
import os

DATA_BASE = './JPEGImages/'
ANNOTATION_DIR = 'Annotations'


def clamp_box(path, xmin, ymin, width, height, img_width, img_height):
    """Convert an ``(x, y, w, h)`` face rect into a box kept inside the image.

    Returns ``(xmin, ymin, xmax, ymax)``.  A non-positive ``xmin``/``ymin`` is
    raised to 1 and an ``xmax``/``ymax`` at or past the image border is pulled
    in to ``border - 1``; every adjustment is reported on stdout together with
    *path*.  ``xmax``/``ymax`` are computed from the rect *before* ``xmin`` and
    ``ymin`` are clamped, so the right/bottom edges stay where the rect put them.
    """
    xmax = xmin + width
    ymax = ymin + height
    if xmin <= 0:
        print('%s xmin %s' % (path, xmin))
        xmin = 1
    if ymin <= 0:
        print('%s ymin %s' % (path, ymin))
        ymin = 1
    if xmax >= img_width:
        print('%s xmax %s img_width %s img_height %s'
              % (path, xmax, img_width, img_height))
        xmax = img_width - 1
    if ymax >= img_height:
        print('%s ymax %s' % (path, ymax))
        ymax = img_height - 1
    return xmin, ymin, xmax, ymax


def format_header(path, imageid, img_width, img_height):
    """Return the YAML preamble of one annotation file, up to ``object:``."""
    # Built line by line (not with one big % format) because '%YAML' would be
    # parsed as a format directive.
    lines = [
        '%YAML:1.0',
        '',
        'annotation:',
        '  folder: ALFW',
        '  filename: "%s"' % path,
        '  source: {id: %s}' % imageid,
        '  owner: {name: zhuqian}',
        "  size: {width: '%s', height: '%s', depth: '3'}" % (img_width, img_height),
        "  segmented: '0'",
        '  object:',
    ]
    return '\n'.join(lines) + '\n'


def format_object(box):
    """Return the YAML list item for one face box ``(xmin, ymin, xmax, ymax)``."""
    return (
        "    - bndbox: {xmin: '%s', ymin: '%s', xmax: '%s', ymax: '%s'}\n"
        "      name: face\n"
        "      pose: Left\n"
        "      truncated: '1'\n"
        "      difficult: '0'\n"
    ) % tuple(box)


def read_faces(face_file):
    """Group the face rects of *face_file* by image path.

    Returns a list of ``(path, id, [(x, y, w, h), ...])`` in the order in which
    each path first appears.  The id kept is the one on the first line for that
    path.  The header line and blank lines are skipped.
    """
    next(face_file, None)  # header line
    order = []
    by_path = {}
    for line in face_file:
        fields = line.split()
        if not fields:
            continue
        imageid, path = fields[0], fields[1]
        rect = tuple(int(value) for value in fields[2:6])
        if path not in by_path:
            by_path[path] = (imageid, [])
            order.append(path)
        by_path[path][1].append(rect)
    return [(path, by_path[path][0], by_path[path][1]) for path in order]


def main():
    """Write one annotation file per image listed in face_rect.txt."""
    import cv2  # local import so the helpers above are usable without OpenCV

    with open('face_rect.txt') as face_file:
        images = read_faces(face_file)

    for path, imageid, rects in images:
        img = cv2.imread(DATA_BASE + path)
        if img is None:  # cv2.imread signals failure by returning None
            raise IOError('cannot read image: %s' % (DATA_BASE + path))
        img_height, img_width = img.shape[:2]

        yml_path = ANNOTATION_DIR + '/' + os.path.splitext(path)[0] + '.yml'
        yml_dir = os.path.dirname(yml_path)
        if not os.path.isdir(yml_dir):
            os.makedirs(yml_dir)

        # Write the whole file at once ('w'), so re-running the script replaces
        # annotations instead of appending duplicate boxes to them.
        with open(yml_path, 'w') as out:
            out.write(format_header(path, imageid, img_width, img_height))
            for xmin, ymin, width, height in rects:
                box = clamp_box(path, xmin, ymin, width, height,
                                img_width, img_height)
                out.write(format_object(box))


if __name__ == '__main__':
    main()
2. 随机划分训练集和测试集。
"""Randomly split the images of face_rect.txt into train/test image lists.

Writes ``ImageSets/Main/train.txt`` and ``ImageSets/Main/test.txt`` (one image
path without extension per line) and ``ImageSets/Main/class.txt`` (``face``).
"""
import os
import random

# NOTE(review): as in the original script, only a quarter of the images go to
# training and the remaining three quarters to testing; confirm this is intended.
TRAIN_FRACTION = 0.25


def unique_image_names(face_file):
    """Return the distinct image paths (extension stripped) of *face_file*.

    face_rect.txt has one line per face, so an image with several faces is
    listed several times; it is kept once, in first-seen order, so that it
    cannot end up in both the train and the test list.  The header line and
    blank/short lines are skipped.
    """
    next(face_file, None)  # header line
    names = []
    seen = set()
    for line in face_file:
        fields = line.split()
        if len(fields) < 2:
            continue
        name = os.path.splitext(fields[1])[0]
        if name not in seen:
            seen.add(name)
            names.append(name)
    return names


def split_images(names, train_fraction=TRAIN_FRACTION, rng=random):
    """Shuffle *names* and split them into ``(train, test)`` lists.

    The first ``int(len(names) * train_fraction)`` shuffled names are the
    training set and *all* remaining names are the test set, so every name
    lands in exactly one of the two lists.  *names* itself is not modified.
    """
    shuffled = list(names)
    rng.shuffle(shuffled)
    n_train = int(len(shuffled) * train_fraction)
    return shuffled[:n_train], shuffled[n_train:]


def write_lines(filename, lines):
    """Write each item of *lines* to *filename*, one per line."""
    with open(filename, 'w') as out:
        out.write(''.join(line + '\n' for line in lines))


def main():
    """Build the train/test/class lists under ImageSets/Main."""
    with open('face_rect.txt', 'r') as face_file:
        images = unique_image_names(face_file)
    train, test = split_images(images)

    if not os.path.isdir('ImageSets/Main'):
        os.makedirs('ImageSets/Main')
    write_lines('ImageSets/Main/train.txt', train)
    write_lines('ImageSets/Main/test.txt', test)
    write_lines('ImageSets/Main/class.txt', ['face'])


if __name__ == '__main__':
    main()
3. 将非 JPG 格式的图片转换为 JPG 格式。
"""Re-encode every image of face_rect.txt whose extension is not '.jpg'.

For each such image under ``./JPEGImages/`` a ``.jpg`` copy with the same path
(extension replaced) is written next to it; the source file is left in place.
The comparison is case-sensitive, so e.g. ``.JPG`` files are converted too.
"""
import os

DATA_BASE = './JPEGImages/'


def images_to_convert(face_file):
    """Return the distinct image paths of *face_file* that are not '.jpg'.

    Paths are returned in first-seen order and only once each, although
    face_rect.txt lists an image once per face.  The header line and
    blank/short lines are skipped.
    """
    next(face_file, None)  # header line
    result = []
    seen = set()
    for line in face_file:
        fields = line.split()
        if len(fields) < 2:
            continue
        path = fields[1]
        if path in seen:
            continue
        seen.add(path)
        if os.path.splitext(path)[1] != '.jpg':
            result.append(path)
    return result


def main():
    """Convert the listed non-jpg images to jpg."""
    import cv2  # local import so images_to_convert is usable without OpenCV

    with open('face_rect.txt') as face_file:
        paths = images_to_convert(face_file)

    for path in paths:
        print(path)
        img = cv2.imread(DATA_BASE + path)
        if img is None:  # cv2.imread signals failure by returning None
            raise IOError('cannot read image: %s' % (DATA_BASE + path))
        cv2.imwrite(DATA_BASE + os.path.splitext(path)[0] + '.jpg', img)


if __name__ == '__main__':
    main()