视线估计Gaze-Estimation PFLD实现
gaze-estimation问题的主要难点是数据集标注极其困难,针对最近开源的一个数据集,尝试用回归的方式进行了训练。
整个项目源码:https://github.com/ycdhqzhiai/Gaze-PFLD
1.数据集
- 数据集预处理
这里将其转换为Json格式,只保留landmarks和gaze-vector,其他标注信息没有用到
import os
import cv2
import glob
import numpy as np
import argparse
import json
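# NOTE: frame indices in the saved file names are zero-padded to 5 digits, so names stay fixed-width only up to 99999 frames (about 55 minutes of video at 30 fps).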
def parse_args(argv=None):
    """Parse the command-line options of the dataset conversion script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what the script entry
            point uses; passing a list allows the parser to be driven
            programmatically.

    Returns:
        argparse.Namespace with the attributes ``video_path``,
        ``annotations``, ``images``, ``draw_img``, ``blind`` and ``json``.
    """
    parser = argparse.ArgumentParser(description="EyeGaze datasets")
    parser.add_argument("--video_path", type=str, default='DIKABLISVIDEOS', help='videos path')
    parser.add_argument("--annotations", type=str, default='ANNOTATIONS', help='videos label path including gaze_vec iris_lm_2D lid_lm_2D pupil_lm_2D')
    parser.add_argument("--images", type=str, default='images', help='save_path')
    parser.add_argument("--draw_img", type=str, default='draw_img', help='save_path')
    parser.add_argument("--blind", type=str, default='blind', help='save_path')
    parser.add_argument("--json", type=str, default='json', help='save_path')
    return parser.parse_args(argv)
def mkd(path):
    """Create directory ``path``, including missing parents, if it is absent.

    ``exist_ok=True`` makes the call a no-op when the directory is already
    there and avoids the window between a separate existence check and the
    creation in which another process could create it first.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
def judge_exists(path):
    """Report whether ``path`` is missing.

    Despite the name, the result is ``True`` when nothing exists at ``path``
    and ``False`` when something does.
    """
    return not os.path.exists(path)
def log(agaze_vec, airis_lm_2D, alid_lm_2D, apupil_lm_2D, aeye_movements):
    """Check that all five annotation files are present.

    Every path is checked; one line is printed for each missing file, in the
    order gaze_vec, iris_lm_2D, lid_lm_2D, pupil_lm_2D, eye_movements.

    Returns:
        ``True`` when every path exists, ``False`` when at least one is
        missing. The function itself does not exit; the caller decides.
    """
    labelled_paths = (
        ('gaze_vec', agaze_vec),
        ('iris_lm_2D', airis_lm_2D),
        ('lid_lm_2D', alid_lm_2D),
        ('pupil_lm_2D', apupil_lm_2D),
        ('eye_movements', aeye_movements),
    )
    missing = [label for label, path in labelled_paths if not os.path.exists(path)]
    for label in missing:
        print('{} not found!!! EXIT'.format(label))
    return not missing
def _read_label_lines(path, header_lines):
    """Return the lines of ``path`` with the first ``header_lines`` lines dropped."""
    with open(path, 'r') as label_file:
        return label_file.readlines()[header_lines:]


def _parse_points(line):
    """Parse one ';'-separated landmark row into an (N, 2) float array.

    The first two fields and the last field of the row are skipped; the
    remaining values are read as consecutive x, y pairs.
    """
    return np.array([float(x) for x in line.strip().split(';')[2:-1]]).reshape(-1, 2)


def _draw_points(image, points, color):
    """Draw every (x, y) point as a filled circle of radius 1 on ``image``."""
    for x_y in points:
        cv2.circle(image, (int(x_y[0]), int(x_y[1])), 1, color, -1)


def main():
    """Convert every ``*.mp4`` under ``--video_path`` into images and JSON labels.

    For each video, the five annotation files named
    ``<video file name> + <suffix>.txt`` are read from ``--annotations`` and
    the video is walked frame by frame:

    * frames whose eye-movement label character is ``'1'`` are skipped;
    * frames whose gaze vector contains ``-1`` get their annotated image and
      JSON label written to the ``--blind`` directory;
    * of the remaining frames, every one whose 1-based index is a multiple of
      3 gets the raw image, the JSON label and the annotated image written to
      the ``--images``, ``--json`` and ``--draw_img`` directories.

    Raises:
        SystemExit: with status 1 when an annotation file is missing
            (``log`` has already printed which one).
    """
    args = parse_args()
    video_list = glob.glob(os.path.join(args.video_path, '*.mp4'))
    for video in video_list:
        # ``name`` keeps the '.mp4' suffix: it is used for the output
        # sub-directories and as the prefix of the annotation file names.
        name = os.path.split(video)[1]
        stem = name[:-4]  # file name without the '.mp4' suffix
        images_dir = os.path.join(args.images, name)
        draw_img_dir = os.path.join(args.draw_img, name)
        blind_dir = os.path.join(args.blind, name)
        json_dir = os.path.join(args.json, name)
        mkd(images_dir)
        mkd(draw_img_dir)
        mkd(blind_dir)
        mkd(json_dir)

        agaze_vec = os.path.join(args.annotations, name + 'gaze_vec.txt')
        airis_lm_2D = os.path.join(args.annotations, name + 'iris_lm_2D.txt')
        alid_lm_2D = os.path.join(args.annotations, name + 'lid_lm_2D.txt')
        apupil_lm_2D = os.path.join(args.annotations, name + 'pupil_lm_2D.txt')
        aeye_movements = os.path.join(args.annotations, name + 'eye_movements.txt')
        if not log(agaze_vec, airis_lm_2D, alid_lm_2D, apupil_lm_2D, aeye_movements):
            # Missing annotations are an error, so report a non-zero status.
            raise SystemExit(1)

        # One header line is dropped from the landmark/gaze files, three from
        # the eye-movement file.
        lgaze_vec = _read_label_lines(agaze_vec, 1)
        liris_lm_2D = _read_label_lines(airis_lm_2D, 1)
        llid_lm_2D = _read_label_lines(alid_lm_2D, 1)
        lpupil_lm_2D = _read_label_lines(apupil_lm_2D, 1)
        leye_movements = _read_label_lines(aeye_movements, 3)
        labelled_frames = min(len(lgaze_vec), len(liris_lm_2D), len(llid_lm_2D),
                              len(lpupil_lm_2D), len(leye_movements))

        cap = cv2.VideoCapture(video)
        try:
            num = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if num >= labelled_frames:
                    # The video has more frames than annotation rows; stop
                    # instead of failing with an IndexError.
                    print('{}: no annotation for frame {} and later, stopping'.format(name, num))
                    break

                src = frame.copy()  # untouched copy; ``frame`` gets drawn on
                frame_tag = '{}_{:0>5d}'.format(stem, num)
                save_src = '{}/{}.jpg'.format(images_dir, frame_tag)
                save_draw = '{}/{}.jpg'.format(draw_img_dir, frame_tag)
                save_blind = '{}/{}.jpg'.format(blind_dir, frame_tag)
                save_json = '{}/{}.json'.format(json_dir, frame_tag)
                # Built from blind_dir directly so it works with any path
                # separator and any --json/--blind value.
                save_blind_json = '{}/{}.json'.format(blind_dir, frame_tag)

                eye_movements = leye_movements[num].strip()[2:3]
                gaze_vec = np.array([float(x) for x in lgaze_vec[num].strip().split(';')[1:3]])
                iris_lm_2D = _parse_points(liris_lm_2D[num])    # iris (middle region)
                lid_lm_2D = _parse_points(llid_lm_2D[num])      # eyelid (outermost region)
                pupil_lm_2D = _parse_points(lpupil_lm_2D[num])  # pupil (innermost region)

                # File names above use the 0-based index; the sub-sampling
                # test below uses the already incremented counter.
                num += 1
                if eye_movements == '1':
                    continue

                eye_c = np.mean(pupil_lm_2D, axis=0).astype(int)
                _draw_points(frame, iris_lm_2D, (0, 255, 0))   # green
                _draw_points(frame, lid_lm_2D, (255, 0, 0))    # blue
                _draw_points(frame, pupil_lm_2D, (0, 0, 255))  # red
                # Plain Python ints keep OpenCV's point parsing independent of
                # the numpy integer width.
                center = tuple(int(v) for v in eye_c)
                gaze_tip = tuple(int(v) for v in eye_c + (gaze_vec * 100).astype(int))
                cv2.circle(frame, center, 1, (255, 255, 255), -1)
                cv2.line(frame, center, gaze_tip, (0, 255, 255), 1)  # yellow

                label_dict = {'gaze_vec': gaze_vec.tolist(), 'iris_lm_2D': iris_lm_2D.tolist(), 'lid_lm_2D': lid_lm_2D.tolist(), 'pupil_lm_2D': pupil_lm_2D.tolist()}
                if -1 in gaze_vec:
                    # Presumably -1 marks an invalid gaze annotation; such
                    # frames are kept apart from the training data.
                    cv2.imwrite(save_blind, frame)
                    with open(save_blind_json, 'w') as dump_f:
                        json.dump(label_dict, dump_f)
                elif num % 3 == 0:
                    cv2.imwrite(save_src, src)
                    with open(save_json, 'w') as dump_f:
                        json.dump(label_dict, dump_f)
                    # NOTE(review): the original indentation was lost; the
                    # annotated image is assumed to be saved together with
                    # each kept sample -- confirm against the repository.
                    cv2.imwrite(save_draw, frame)
        finally:
            # Always release the decoder, even if a frame fails to parse.
            cap.release()


if __name__ == '__main__':
    main()
2.训练
使用PFLD来训练gaze-estimation,PFLDInference骨干网络用来预测landmarks,AuxiliaryNet网络用来预测gaze-vector。
- dataloader
def preprocess_unityeyes_image(img, json_data, datasets, input_width, input_height):
    """Crop the eye region of ``img`` and return it with landmarks and gaze.

    Args:
        img: Image array; passed unchanged to ``get_img``.
        json_data: Parsed label dict. For ``datasets == 'B'`` the keys
            ``'gaze'`` and ``'landmarks'`` are read; for ``'E'`` the keys
            ``'gaze_vec'``, ``'lid_lm_2D'``, ``'iris_lm_2D'`` and
            ``'pupil_lm_2D'``.
        datasets: Dataset tag, ``'B'`` or ``'E'``.
        input_width: Width the crop is resized to.
        input_height: Height the crop is resized to.

    Returns:
        Tuple ``(crop_img, lad, gaze)``: the resized crop, the landmarks as
        returned by ``get_img``, and the gaze label as a numpy array.

    Raises:
        NotImplementedError: for any other dataset tag (including ``'U'``).
    """
    if datasets == 'B':
        gaze = np.array(json_data['gaze'])
        landmarks = np.array(json_data['landmarks'])
    elif datasets == 'E':
        gaze = np.array(json_data['gaze_vec'])
        lid = np.array(json_data['lid_lm_2D'])
        iris = np.array(json_data['iris_lm_2D'])
        pupil = np.array(json_data['pupil_lm_2D'])
        # Centre of the iris landmarks' bounding box; appended as the last
        # landmark after eyelid, iris and pupil points.
        eye_middle = np.mean([np.amin(iris, axis=0), np.amax(iris, axis=0)], axis=0)
        landmarks = np.concatenate((lid, iris, pupil, eye_middle.reshape(1, 2)))
    else:
        # Raise instead of print + exit() so the failure carries a traceback
        # and does not silently terminate the calling process.
        raise NotImplementedError(
            "preprocessing for dataset {!r} is not implemented (supported: 'B', 'E')".format(datasets))

    crop_img, lad = get_img(img, landmarks)
    crop_img = cv2.resize(crop_img, (input_width, input_height))
    return crop_img, lad, gaze
class EyesDataset(data.Dataset):
    """Dataset of eye images with per-image JSON labels.

    Images are the ``*.jpg`` files directly inside
    ``<dataroot>/<dataset dir>/images``; the label of ``x.jpg`` is expected at
    ``<dataroot>/<dataset dir>/json/x.json``.

    Args:
        datasets: Dataset tag: ``'U'`` (UnityEyes), ``'E'`` (Eye200W) or
            ``'B'`` (BL_Eye).
        dataroot: Directory that contains the dataset directories.
        transforms: Optional callable applied to the cropped eye image.
        input_width: Width passed to ``preprocess_unityeyes_image``.
        input_height: Height passed to ``preprocess_unityeyes_image``.

    Raises:
        ValueError: if ``datasets`` is not one of the known tags.
    """

    # Directory name under ``dataroot`` for each dataset tag.
    _DATASET_DIRS = {'U': 'UnityEyes', 'E': 'Eye200W', 'B': 'BL_Eye'}

    def __init__(self, datasets, dataroot, transforms=None, input_width=160, input_height=112):
        if datasets not in self._DATASET_DIRS:
            raise ValueError(
                'unknown dataset {!r}, expected one of {}'.format(datasets, sorted(self._DATASET_DIRS)))
        self.dataroot = dataroot
        self.datasets = datasets
        self.input_width = input_width
        self.input_height = input_height
        self.transforms = transforms

        dataset_dir = os.path.join(dataroot, self._DATASET_DIRS[datasets])
        images_dir = os.path.join(dataset_dir, 'images')
        json_dir = os.path.join(dataset_dir, 'json')
        # The pattern must be relative: a component starting with '/' makes
        # os.path.join drop everything before it.
        self.img_paths = sorted(glob.glob(os.path.join(images_dir, '*.jpg')))
        # Derive label paths from the directory layout rather than a global
        # string replace, which would also rewrite 'images' inside dataroot.
        self.json_paths = [
            os.path.join(json_dir, os.path.splitext(os.path.basename(img_path))[0] + '.json')
            for img_path in self.img_paths
        ]

    def __getitem__(self, index):
        """Return ``(eye, landmarks, gaze)`` for the sample at ``index``.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        if torch.is_tensor(index):
            index = index.tolist()
        img_path = self.img_paths[index]
        full_img = cv2.imread(img_path)
        if full_img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError('cannot read image: {}'.format(img_path))
        with open(self.json_paths[index]) as f:
            json_data = json.load(f)
        eye, landmarks, gaze = preprocess_unityeyes_image(
            full_img, json_data, self.datasets, self.input_width, self.input_height)
        if self.transforms:
            eye = self.transforms(eye)
        return eye, landmarks, gaze

    def __len__(self):
        """Return the number of images found."""
        return len(self.img_paths)
- model
class Gaze_PFLD(nn.Module):
    """Single module wrapping the landmark backbone and the gaze head."""

    def __init__(self):
        super().__init__()
        self.lad = PFLDInference()   # produces (features, landmark)
        self.gaze = AuxiliaryNet()   # maps the backbone features to gaze

    def forward(self, x):
        """Run both sub-networks on ``x`` and return ``(landmark, gaze)``."""
        backbone_features, landmark_pred = self.lad(x)
        return landmark_pred, self.gaze(backbone_features)
- loss
class PFLDLoss(nn.Module):
    """Training loss with a landmark term and a gaze term."""

    def __init__(self):
        super().__init__()
        self.gaze_loss = nn.MSELoss()

    def forward(self, landmark_gt, landmarks, gaze_pred, gaze):
        """Return ``(gaze_loss * 1000, landmark_loss)``.

        The landmark term is ``wing_loss(landmark_gt, landmarks)``; the gaze
        term is the mean squared error between ``gaze_pred`` and ``gaze``,
        scaled by 1000.
        """
        landmark_term = wing_loss(landmark_gt, landmarks)
        gaze_term = self.gaze_loss(gaze_pred, gaze)
        return gaze_term * 1000, landmark_term
def wing_loss(y_true, y_pred, w=10.0, epsilon=2.0, N_LANDMARK=51):
    """Compute the wing loss between two sets of 2-D landmarks.

    Both inputs are reshaped to ``(-1, N_LANDMARK, 2)``. For each coordinate
    with absolute error ``e`` the loss is ``w * log(1 + e / epsilon)`` when
    ``e < w`` and ``e - c`` otherwise, where
    ``c = w * (1 - log(1 + w / epsilon))``.

    Returns:
        Scalar tensor: per-sample sums over all coordinates, averaged over
        the batch.
    """
    target = y_true.reshape(-1, N_LANDMARK, 2)
    prediction = y_pred.reshape(-1, N_LANDMARK, 2)
    abs_err = (target - prediction).abs()

    offset = w * (1.0 - math.log(1.0 + w / epsilon))
    log_branch = w * torch.log(1.0 + abs_err / epsilon)
    linear_branch = abs_err - offset
    per_coordinate = torch.where(abs_err < w, log_branch, linear_branch)

    return per_coordinate.sum(dim=(1, 2)).mean(dim=0)
3.demo
import argparse
import numpy as np
import cv2
import torch
import torchvision
from models.pfld import PFLDInference, AuxiliaryNet
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args):
    """Run the gaze demo on ``5.png`` using the checkpoint at ``args.model_path``.

    Loads the ``'pfld_backbone'`` and ``'auxiliarynet'`` state dicts from the
    checkpoint, predicts landmarks and gaze for the image resized to 160x112,
    draws them on the full-size image, writes the result to ``gaze.jpg`` and
    shows it in a window until a key is pressed.

    Args:
        args: Namespace with a ``model_path`` attribute.

    Raises:
        FileNotFoundError: if ``5.png`` cannot be read.
    """
    checkpoint = torch.load(args.model_path, map_location=device)
    print(checkpoint.keys())
    pfld_backbone = PFLDInference().to(device)
    auxiliarynet = AuxiliaryNet().to(device)
    pfld_backbone.load_state_dict(checkpoint['pfld_backbone'])
    auxiliarynet.load_state_dict(checkpoint["auxiliarynet"])
    pfld_backbone.eval()
    auxiliarynet.eval()

    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()])

    image_path = '5.png'
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None instead of raising on failure.
        raise FileNotFoundError('cannot read image: {}'.format(image_path))
    height, width = img.shape[:2]

    net_input = cv2.resize(img, (160, 112))
    net_input = transform(net_input).unsqueeze(0).to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        features, landmarks = pfld_backbone(net_input)
        gaze = auxiliarynet(features)

    # Landmarks are multiplied by the image size, i.e. presumably predicted
    # in normalised [0, 1] coordinates -- confirm against the training code.
    pre_landmark = landmarks[0].cpu().detach().numpy().reshape(
        -1, 2) * [width, height]
    gaze = gaze.cpu().detach().numpy()[0]

    # The last landmark is used as the origin of the gaze arrow.
    c_pos = pre_landmark[-1, :].astype(int)
    arrow_start = tuple(int(v) for v in c_pos)
    arrow_end = tuple(int(v) for v in c_pos + (gaze * 400).astype(int))
    cv2.line(img, arrow_start, arrow_end, (0, 255, 0), 1)
    for (x, y) in pre_landmark.astype(np.int32):
        cv2.circle(img, (int(x), int(y)), 1, (0, 0, 255))

    # Save first so the result is kept even when no display is available.
    cv2.imwrite('gaze.jpg', img)
    cv2.imshow('gaze estimation', img)
    cv2.waitKey(0)
def parse_args(argv=None):
    """Parse the command-line options of the demo script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``; passing a list allows the parser
            to be driven programmatically.

    Returns:
        argparse.Namespace with the attribute ``model_path``.
    """
    parser = argparse.ArgumentParser(description='Testing')
    parser.add_argument('--model_path',
                        default="./checkpoint/snapshot/checkpoint_epoch_13.pth.tar",
                        type=str)
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    main(args)
效果图
4.export onnx
# from __future__ import absolute_import
# from __future__ import division
# from __future__ import print_function
import argparse
import sys
import time
from models.pfld import Gaze_PFLD
import torch
import torch.nn as nn
import models
# def load_model_weight(model, checkpoint):
# state_dict = checkpoint['model_state_dict']
# # strip prefix of state_dict
# if list(state_dict.keys())[0].startswith('module.'):
# state_dict = {k[7:]: v for k, v in checkpoint['model_state_dict'].items()}
# model_state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
# # check loaded parameters and created model parameters
# for k in state_dict:
# if k in model_state_dict:
# if state_dict[k].shape != model_state_dict[k].shape:
# print('Skip loading parameter {}, required shape{}, loaded shape{}.'.format(
# k, model_state_dict[k].shape, state_dict[k].shape))
# state_dict[k] = model_state_dict[k]
# else:
# print('Drop parameter {}.'.format(k))
# for k in model_state_dict:
# if not (k in state_dict):
# print('No param {}.'.format(k))
# state_dict[k] = model_state_dict[k]
# model.load_state_dict(state_dict, strict=False)
if __name__ == '__main__':
    # Export a trained Gaze_PFLD checkpoint to ONNX and, when the onnx and
    # onnxsim packages are importable, re-export and simplify it.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default="./checkpoint/snapshot/checkpoint.pth.tar", help='weights path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[112, 160], help='image size')  # height, width
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    opt = parser.parse_args()
    opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand a single value to [size, size]
    device = "cpu"

    print("=====> load pytorch checkpoint...")
    checkpoint = torch.load(opt.weights, map_location=torch.device('cpu'))
    net = Gaze_PFLD().to(device)
    net.load_state_dict(checkpoint['gaze_pfld'])
    # Inference mode: without this the dummy forward pass below would run in
    # training mode and could update any normalisation statistics the model
    # keeps before they are exported.
    net.eval()

    # NOTE(review): the dummy input has 1 channel, while the demo script feeds
    # PFLDInference a 3-channel image -- confirm which one the trained model
    # expects.
    img = torch.zeros(opt.batch_size, 1, *opt.img_size).to(device)
    print(img.shape)
    # Sanity-check forward pass; the outputs themselves are not needed.
    with torch.no_grad():
        net(img)

    f = opt.weights.replace('.pth.tar', '.onnx')  # output filename
    if f == opt.weights:
        # The weights file does not end in '.pth.tar'; never write the ONNX
        # model over the checkpoint itself.
        f = opt.weights + '.onnx'

    # Baseline export: this file is what remains if the block below fails.
    torch.onnx.export(net, img, f, export_params=True, verbose=False, opset_version=12, input_names=['inputs'])

    # Re-export (opset 11, named input/output) and simplify; this overwrites
    # the baseline file on success.
    try:
        import onnx
        from onnxsim import simplify

        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        torch.onnx.export(net, img, f, verbose=False, opset_version=11, input_names=['images'],
                          output_names=['output'])

        # Checks
        onnx_model = onnx.load(f)  # load onnx model
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, f)
        print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model
        print('ONNX export success, saved as %s' % f)
    except Exception as e:
        print('ONNX export failure: %s' % e)