Face cutout using facial landmarks: code

Result (images omitted): the original image on the left, the cut-out face on the right.

The network model comes from the lightweight, high-precision face landmark model described in another post of mine:

insightface_landmarks

Recommended lightweight, high-precision face landmarks (jacke121's column, CSDN blog)
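
Before running the script, it helps to confirm that the MXNet checkpoint is in place. The snippet below is a minimal, optional sanity check, assuming the 2d106det checkpoint (the usual MXNet pair 2d106det-symbol.json plus 2d106det-0000.params, which the Handler class below loads) sits in the working directory:

import mxnet as mx

# Load the 2d106det checkpoint (epoch 0) and confirm the layer the script taps into exists.
sym, arg_params, aux_params = mx.model.load_checkpoint('2d106det', 0)
outputs = sym.get_internals().list_outputs()
print('fc1_output' in outputs, outputs[-3:])  # the script reads 'fc1_output'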

import cv2
import numpy as np
import mxnet as mx
from skimage import transform as trans


# Helper: resize so the longer side equals S and pad with black to an S x S square.
def square_crop(im, S):
    if im.shape[0] > im.shape[1]:
        height = S
        width = int(float(im.shape[1]) / im.shape[0] * S)
        scale = float(S) / im.shape[0]
    else:
        width = S
        height = int(float(im.shape[0]) / im.shape[1] * S)
        scale = float(S) / im.shape[1]
    resized_im = cv2.resize(im, (width, height))
    det_im = np.zeros((S, S, 3), dtype=np.uint8)
    det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
    return det_im, scale


# Helper: build a similarity transform (scale, translate to center, rotate) and
# warp the input so `center` maps to the middle of an output_size x output_size crop.
def transform(data, center, output_size, scale, rotation):
    scale_ratio = scale
    rot = float(rotation) * np.pi / 180.0
    #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
    t1 = trans.SimilarityTransform(scale=scale_ratio)
    cx = center[0] * scale_ratio
    cy = center[1] * scale_ratio
    t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
    t3 = trans.SimilarityTransform(rotation=rot)
    t4 = trans.SimilarityTransform(translation=(output_size / 2,
                                                output_size / 2))
    t = t1 + t2 + t3 + t4
    M = t.params[0:2]
    cropped = cv2.warpAffine(data,
                             M, (output_size, output_size),
                             borderValue=0.0)
    return cropped, M


def trans_points2d(pts, M):
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        #print('new_pt', new_pt.shape, new_pt)
        new_pts[i] = new_pt[0:2]

    return new_pts


def trans_points3d(pts, M):
    scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
    #print(scale)
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        #print('new_pt', new_pt.shape, new_pt)
        new_pts[i][0:2] = new_pt[0:2]
        new_pts[i][2] = pts[i][2] * scale

    return new_pts


def trans_points(pts, M):
    if pts.shape[1] == 2:
        return trans_points2d(pts, M)
    else:
        return trans_points3d(pts, M)


class Handler:
    # Wraps the MXNet landmark model: loads the checkpoint and exposes get() for
    # running inference on an already-cropped face image.
    def __init__(self, prefix, epoch, im_size=192, det_size=224, ctx_id=0):
        print('loading', prefix, epoch)
        if ctx_id >= 0:
            ctx = mx.gpu(ctx_id)
        else:
            ctx = mx.cpu()
        image_size = (im_size, im_size)

        self.det_size = det_size
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        all_layers = sym.get_internals()
        sym = all_layers['fc1_output']
        self.image_size = image_size
        model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        model.bind(for_training=False,
                   data_shapes=[('data', (1, 3, image_size[0], image_size[1]))
                                ])
        model.set_params(arg_params, aux_params)
        self.model = model
        self.image_size = image_size

    def get(self, rimg, get_all=False):
        # rimg is expected to already be resized to self.image_size (192 x 192, BGR).
        h, w = rimg.shape[:2]

        input_blob = np.zeros((1, 3) + self.image_size, dtype=np.float32)
        # center / rotate / _scale would feed transform() in a full detection pipeline;
        # they are left unused here because the input is already a tight face crop.
        center = w / 2, h / 2
        rotate = 0
        _scale = self.image_size[0] * 2 / 3.0 / max(w, h)
        rimg = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
        rimg = np.transpose(rimg, (2, 0, 1))  # HWC BGR -> CHW RGB, 3 x 192 x 192
        input_blob[0] = rimg
        data = mx.nd.array(input_blob)
        db = mx.io.DataBatch(data=(data, ))
        self.model.forward(db, is_train=False)
        pred = self.model.get_outputs()[-1].asnumpy()[0]
        # Dense 3D models output >= 3000 values (x, y, z per point); 2d106det gives 106 x 2.
        if pred.shape[0] >= 3000:
            pred = pred.reshape((-1, 3))
        else:
            pred = pred.reshape((-1, 2))
        # fc1 predicts coordinates in [-1, 1]; map them to pixel coordinates.
        pred[:, 0:2] += 1
        pred[:, 0:2] *= (self.image_size[0] // 2)
        if pred.shape[1] == 3:
            pred[:, 2] *= (self.image_size[0] // 2)

        return pred


if __name__ == '__main__':
    handler = Handler('2d106det', 0, ctx_id=-1, det_size=320)
    # im = cv2.imread(r'd:\qinlan4.jpg')
    im = cv2.imread(r'G:\data\bianshi\100w_huochezhan\tiaotu\data_51\1e8f6_20180104062753977_IDCard.bmp')

    # The landmark network expects a 192 x 192 face crop.
    im = cv2.resize(im, (192, 192))
    tim = im.copy()
    pred = handler.get(im, get_all=True)
    color = (200, 160, 75)
    font = cv2.FONT_HERSHEY_SIMPLEX

    new_points = []
    pred = np.round(pred).astype(np.int32)  # np.int is removed in recent NumPy; OpenCV wants int32

    # Vertical offset used below to push the retained upper points up so the crop
    # polygon also covers the forehead.
    height = pred[9][1] - pred[1][1]

    # Keep mainly the face-contour landmarks (plus a few upper points, shifted up by
    # `height` so the polygon covers the forehead); skip the interior eye / eyebrow /
    # nose / mouth landmarks.
    for i in range(pred.shape[0]):
        if 33 <= i <= 42:
            continue
        if 50 <= i <= 96:
            continue
        if 44 <= i <= 47:
            continue
        if 97 <= i <= 100:
            continue
        if i == 103 or i == 102:
            continue
        if i in [43, 48, 49, 50, 101, 104, 105]:
            pred[i][1] -= height
        p = tuple(pred[i])
        new_points.append(pred[i])
        # tim = cv2.putText(tim, str(i), (int(p[0] + 5), int(p[1])), font, 0.4,
        #                   (0, 0, 255), 1)
        # cv2.circle(tim, p, 1, color, 1, cv2.LINE_AA)

    from functools import reduce
    import operator
    import math

    # Sort the kept points by angle around their centroid so they form a proper
    # polygon for cv2.fillConvexPoly.
    coords = new_points
    center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), coords), [len(coords)] * 2))
    new_points = sorted(coords, key=lambda coord: (-135 - math.degrees(
        math.atan2(*tuple(map(operator.sub, coord, center))[::-1]))) % 360, reverse=True)

    img_origin = tim.copy()

    # for index, pred in enumerate(new_points):
    #     pred = np.round(pred).astype(np.int)
    #     p = tuple(pred)
    #     tim = cv2.putText(tim,str(index), (int(p[0]+5), int(p[1])), font, 0.4,
    #                           (0, 0, 255), 1)
    #     cv2.circle(tim, p, 1, color, 1, cv2.LINE_AA)

    img_n = tim.copy()

    # Fill the landmark polygon with value 1; XOR-ing with the untouched original then
    # blacks out everything outside the polygon and keeps the face region (almost) intact.
    cv2.fillConvexPoly(img_n, np.array(new_points, dtype=np.int32), 1)
    # cv2.fillPoly(img_n, [np.array(new_points[:3]), np.array(new_points)], 1)
    bitwisexor = cv2.bitwise_xor(img_n, img_origin)

    cv2.imshow("bitwisexor", bitwisexor)

    cv2.imshow('output', tim)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
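
The fillConvexPoly plus bitwise_xor trick above blacks out everything outside the landmark polygon, but it also flips the lowest blue bit inside the face. If a pixel-exact cutout is preferred, the same polygon can be turned into a mask and combined with bitwise_and. A minimal sketch, assuming the tim and new_points variables produced by the script above:

# Alternative cutout (append inside the __main__ block above): build a single-channel
# mask from the sorted landmark polygon and keep only the masked pixels, unchanged.
mask = np.zeros(tim.shape[:2], dtype=np.uint8)
cv2.fillConvexPoly(mask, np.array(new_points, dtype=np.int32), 255)
face_only = cv2.bitwise_and(tim, tim, mask=mask)
cv2.imshow('face_only', face_only)
cv2.waitKey(0)
cv2.destroyAllWindows()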

 
