深度有趣 | 27 服饰关键点定位

本文介绍了如何利用CPM(Convolutional Pose Machines)在TensorFlow中实现服饰关键点定位,详细阐述了任务原理、数据来源、模型实现过程,并展示了训练与测试结果。
摘要由CSDN通过智能技术生成

简介

介绍如何使用CPM(Convolutional Pose Machines)实现服饰关键点定位

原理

关键点定位是一类常见而有用的任务,某种意义上可以理解为一种特征工程

  • 人脸关键点定位,可用于人脸识别、表情识别
  • 人体骨骼关键点定位,可用于姿态估计
  • 手部关键点定位,可用于手势识别

输入是一张图片,输出是每个关键点的x、y坐标,一般会归一化到0~1区间中,所以可以理解为回归问题

但是直接对坐标值进行回归会导致较大误差,更好的做法是输出一个低分辨率的热图,使得关键点所在位置输出较高响应,而其他位置则输出较低响应

关键点检测热图示例

CPM(Convolutional Pose Machines)的基本思想是使用多个级联的stage,每个stage包含多个CNN并且都输出热图

通过最小化每个stage的热图和ground truth之间的差距,从而得到越来越准确的关键点定位结果

CPM模型结构图

数据

使用天池FashionAI全球挑战赛提供的数据,http://fashionai.alibaba.com/

其中服饰关键点定位赛题提供的训练集包括4W多张图片,测试集包括将近1W张图片

每张图片都指定了对应的服饰类别,共5类:上衣(blouse)、外套(outwear)、连身裙(dress)、半身裙(skirt)、裤子(trousers)

不同服饰类别对应的关键点

训练集还提供了每张图片对应的24个关键点的标注,包括x坐标、y坐标、是否可见三项信息,但并不是每类服饰都有24个关键点

关于以上数据的更多介绍可以参考以下文章,https://zhuanlan.zhihu.com/p/34928763

为了简化问题,以下仅使用dress类别的训练集数据训练CPM模型

实现

加载库

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
import cv2
import matplotlib.pyplot as plt
%matplotlib inline
from imageio import imread, imsave
import os
import glob
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

加载训练集和测试集

train = pd.read_csv(os.path.join('data', 'train', 'train.csv'))
train_warm = pd.read_csv(os.path.join('data', 'train_warm', 'train_warm.csv'))
test = pd.read_csv(os.path.join('data', 'test', 'test.csv'))
print(len(train), len(train_warm), len(test))

columns = train.columns
print(len(columns), columns)

train['image_id'] = train['image_id'].apply(lambda x:os.path.join('train', x))
train_warm['image_id'] = train_warm['image_id'].apply(lambda x:os.path.join('train_warm', x))
train = pd.concat([train, train_warm])
del train_warm
train.head()

仅保留dress类别的数据

train = train[train.image_category == 'dress']
test = test[test.image_category == 'dress']
print(len(train), len(test))

拆分标注信息中的x坐标、y坐标、是否可见

for col in columns:
    if col in ['image_id', 'image_category']:
        continue
    train[col + '_x'] = train[col].apply(lambda x:float(x.split('_')[0]))
    train[col + '_y'] = train[col].apply(lambda x:float(x.split('_')[1]))
    train[col + '_s'] = train[col].apply(lambda x:float(x.split('_')[2]))
    train.drop([col], axis=1, inplace=True)
train.head()

将x坐标和y坐标进行归一化

features = [
    'neckline_left', 'neckline_right', 'center_front', 'shoulder_left', 'shoulder_right', 
    'armpit_left', 'armpit_right', 'waistline_left', 'waistline_right', 
    'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'hemline_left', 'hemline_right']

train = train.to_dict('records')
for 
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值