Flask + PRNet: a 3D face reconstruction face-swap service

1. Face-swap pipeline

3D reconstruction recovers the 3D shape of the reference image and yields its color space (the identity); the 3D shape of the face in the video is reconstructed and its vertices extracted (the shape); combining the reference image's colors with the target frame's vertices then renders a face whose identity has been replaced.
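These three steps map onto a handful of calls in the server code in section 7. A minimal sketch, assuming the face_detector and render_colors objects defined there and PRNet's standard helpers (prn.process, prn.get_vertices, prn.get_colors_from_texture):

import cv2
import numpy as np

def swap_identity(ref_image, frame, prn, face_detector):
    # 1. Reconstruct the reference face and sample its colors (the identity).
    boxes, _ = face_detector(ref_image)
    ref_pos = prn.process(ref_image, image_info=boxes[0])
    ref_texture = cv2.remap((ref_image / 255.).astype(np.float32),
                            ref_pos[:, :, :2].astype(np.float32), None,
                            interpolation=cv2.INTER_NEAREST)
    new_colors = prn.get_colors_from_texture(ref_texture)

    # 2. Reconstruct the face in the video frame; keep only its vertices (the shape).
    frame_pos = prn.process(frame)
    vertices = prn.get_vertices(frame_pos)

    # 3. Render the reference colors onto the video face's geometry.
    h, w, _ = frame.shape
    rendered = render_colors(vertices, prn.triangles, new_colors, h, w)
    return (255 * rendered).astype(np.uint8)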

An offline service extracts the 3D vertex information for the face in each frame of the target video and stores it in Redis. Because the vertices need at least float32 precision, keeping a whole video's vertices under one key would make the value very large, so the video is split into many segments, each stored under its own Redis key. At read time the segments are fetched by multiple threads before the swap runs, which is much faster than a single serial read.
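A sketch of the sharded read, using the same key scheme as the server code below ("<video_id>-1" ... "<video_id>-N", here up to 120 shards):

from concurrent.futures import ThreadPoolExecutor, as_completed

def fetch_vertex_shards(redis_db, video_id, max_shards=120, workers=8):
    # Each shard stores a stringified dict of frame_key -> frame data;
    # fetch every existing shard in parallel instead of one by one.
    with ThreadPoolExecutor(max_workers=workers) as executor:
        keys = [video_id + "-" + str(i + 1) for i in range(max_shards)]
        tasks = [executor.submit(redis_db.get, key) for key in keys if redis_db.exists(key)]
        return [eval(task.result()) for task in as_completed(tasks)]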

2. Techniques involved

3D face reconstruction, image rendering, image inpainting, edge detection, face segmentation

3D face reconstruction: the network used is PRNet

3. Known problems and solutions

Jitter: the swapped face jitters between video frames. Smoothing the face-detection boxes over time reduces the jitter markedly, which confirmed that it is caused by low detection precision; the Face++ face-detection API is currently used for detection.
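One simple smoother is an exponential moving average over the detection-box coordinates; this sketch is illustrative rather than the exact filter used:

import numpy as np

def smooth_box(prev_box, cur_box, momentum=0.7):
    # Blend the previous box into the current detection; a higher momentum
    # gives a steadier (but laggier) box across frames.
    if prev_box is None:
        return np.asarray(cur_box, dtype=np.float32)
    return momentum * np.asarray(prev_box, np.float32) + (1.0 - momentum) * np.asarray(cur_box, np.float32)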

Edge artifacts: the swap mask causes artifacts along the face boundary; combining a template mask with face segmentation yields a precise swap mask.
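One plausible way to combine the two mask sources (the function and parameter names here are illustrative):

import cv2
import numpy as np

def refine_swap_mask(template_mask, seg_mask, erode_px=5):
    # Keep only pixels that are face in BOTH the template mask and the
    # segmentation output, then erode so the blend boundary stays inside
    # the face, away from the artifact-prone edge.
    mask = ((template_mask > 0) & (seg_mask > 0)).astype(np.uint8) * 255
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_px, erode_px))
    return cv2.erode(mask, kernel)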

Glasses: image inpainting is applied; an edge-detection method produces the mask, and the image is inpainted according to that mask.
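As a classical stand-in for the learned model actually used below (the Pluralistic FaceEditor), an edge-detected mask plus OpenCV inpainting would look like this:

import cv2

def inpaint_glasses(image):
    # Build a mask from strong edges (a rough proxy for glasses frames),
    # dilate it, and fill the masked region by diffusion-based inpainting.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    mask = cv2.dilate(edges, None, iterations=2)
    return cv2.inpaint(image, mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)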

Eye movement and mouth opening: the eyes and mouth of the reconstructed face are fixed, so those regions are removed from the swap mask to keep the eyes and mouth of the face in the video.
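Removing those regions amounts to filling the eye and mouth polygons (e.g. taken from facial landmarks, which are assumed inputs here) with zeros in the swap mask:

import cv2
import numpy as np

def drop_eye_mouth_regions(mask, eye_polys, mouth_poly):
    # Zero out the eyes and mouth so the blend keeps the video's own eyes
    # and mouth, allowing blinks and an opening mouth to show through.
    for pts in list(eye_polys) + [mouth_poly]:
        cv2.fillConvexPoly(mask, np.asarray(pts, dtype=np.int32), 0)
    return mask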

Color difference between the reference image and the original frame: color correction adjusts the image's brightness and tone.
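The correction shifts the rendered face's per-channel mean inside the mask to match the video frame, which is what the colorTransfer function in section 7 does; its core in isolation:

import numpy as np

def match_mean_color(src, dst, mask):
    # Shift dst's per-channel mean inside the mask to src's mean, then clip.
    idx = np.where(mask != 0)
    shift = src[idx].astype(np.float32).mean(axis=0) - dst[idx].astype(np.float32).mean(axis=0)
    out = dst.astype(np.float32)
    out[idx] += shift
    return np.clip(out, 0, 255).astype(np.uint8)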

4. API description

Version: 1.0

Description: the request carries base64-encoded binary image data and a video name; the detected face is swapped into the video via 3D face reconstruction, and the swapped frames are then assembled back into a video.

Request method: POST

Request URL: xxxxxxxxxx:9775/ai/v1/FaceSwap

Image requirements:

    Format: JPG (JPEG), PNG

    Dimensions: minimum 200*200 pixels, maximum 4096*4096 pixels

5. Overall architecture

6. API design

Request parameters:

| Parameter  | Required | Type   | Description                                            |
|------------|----------|--------|--------------------------------------------------------|
| requestId  | Yes      | String | Unique string id distinguishing each request           |
| inputImage | Yes      | String | Base64 value of the image                              |
| videoName  | Yes      | String | Name of the target video                               |
| token      | Yes      | String | Service auth token, assigned centrally by the AI team  |
| userId     | Yes      | String | User id                                                |
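A minimal client call matching this spec (the host placeholder is kept from above; the token value is illustrative):

import base64
import json
import requests

with open("ref_face.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "requestId": "100022",
    "inputImage": img_b64,
    "videoName": "demo.mp4",
    "token": "<token assigned by the AI team>",
    "userId": "u001",
}
resp = requests.post("http://xxxxxxxxxx:9775/ai/v1/FaceSwap", data=json.dumps(payload))
print(resp.json())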

 

Example response:

{
    "code": 0,
    "msg": "success",
    "data": {
        "requestId": "100022",
        "faceSwapRes": "True",
        "timeUsed": "30.11962342262268"
    }
}

 

Response parameters:

| Parameter   | Type   | Description                                         |
|-------------|--------|-----------------------------------------------------|
| requestId   | String | Unique identifier of the request                    |
| faceSwapRes | String | Swap result: "True" on success, "False" on failure  |
| timeUsed    | String | Total time taken by the request, in seconds         |

 

Status codes (code):

| Code | Meaning                  |
|------|--------------------------|
| 0    | Success                  |
| 2    | No face detected         |
| 3    | Authentication failed    |
| 4    | Invalid parameter        |
| 5    | Image size out of range  |
| 6    | Request exception        |
| 7    | videoName is empty       |

7. Code

# encoding:utf-8
import time  # needed by setLog and the timing code below
import json
import base64
import hashlib
import logging
import os
from logging.handlers import TimedRotatingFileHandler
from threading import Thread, Lock
from io import BytesIO
from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED, FIRST_COMPLETED, as_completed

import numpy as np  # used throughout; previously only available implicitly
import matplotlib.pyplot as plt  # plt.imread is used when assembling the output video
import cv2
import redis
from meinheld import server
from flask import Flask, request
from skimage.io import imread, imsave
from PIL import Image

from conf import config
from api import PRN
from glass_judge import *
# from utils.render import render_texture,render_texture_v1
# from utils.estimate_pose import rotate_pos
# from face_segmentation.face_segment import FaceSegment
from face_segmentation.face_segment import FaceSegmentFCN
from mesh.render import render_colors
from faceDetect.face_detection import FaceDetector, FaceTracker
from face_align import FaceAligner_v1
from Pluralistic.FaceEdit import FaceEditor, CropLayer

app = Flask(__name__)

def setLog():
    log_fmt = '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    formatter = logging.Formatter(log_fmt)
    fh = TimedRotatingFileHandler(
        filename="log/run_faceswap_server" + str(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) + ".log",
        when="H", interval=1,
        backupCount=72)
    fh.setFormatter(formatter)
    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger()
    log.addHandler(fh)


setLog()

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

model_path = "./checkpoint/snapshot/checkpoint_epoch_1000.pth.tar"

# faceLandMark = FaceLandmarkModel(model_path)
# segmentor = FaceSegment('./face_segmentation/checkpoints/model.pt')
fcn_segmentor = FaceSegmentFCN('./face_segmentation/weights/Keras_FCN8s_face_seg_YuvalNirkin.h5')
MODEL_PATH = './faceDetect/model_new.pb'
face_detector = FaceDetector(MODEL_PATH, gpu_memory_fraction=0.25, visible_device_list='0')
face_aligner = FaceAligner_v1()
# cv2.dnn_registerLayer('Crop', CropLayer)
prn = PRN(is_dlib=True)
editor = FaceEditor()
executor = ThreadPoolExecutor(config.threadPoolSize)

# create a connection pool to the Redis database
pool = redis.ConnectionPool(host=config.redisHost, port=config.redisPort, password=config.redisPassword,
                            max_connections=config.maxConnections)
redisDb = redis.Redis(connection_pool=pool)

# shared mutable state, filled by the worker threads and guarded by `lock`
lock = Lock()
swap_threads = []
frame_dict_list = dict()
all_task = list()
imageList = [""] * 5000  # swapped frames, indexed by frame number
frame_count_all = 0
fps = 25
w = 255
h = 255


def colorTransfer(src, dst, mask=None):
    if mask is None:
        h, w, c = dst.shape
        x = np.array(np.arange(w))
        y = np.array(np.arange(h))
        X, Y = np.meshgrid(x, y)
        X = np.reshape(X, (w * h,))
        Y = np.reshape(Y, (w * h,))
        maskIndices = (X, Y)
    else:
        # indices of the non-black pixels in the mask
        maskIndices = np.where(mask != 0)
    transferredDst = np.copy(dst)

    # src[maskIndices[0], maskIndices[1]] selects the pixels inside the non-black mask region
    maskedSrc = src[maskIndices[0], maskIndices[1]].astype(np.int32)
    maskedDst = dst[maskIndices[0], maskIndices[1]].astype(np.int32)
    meanSrc = np.mean(maskedSrc, axis=0)
    meanDst = np.mean(maskedDst, axis=0)
    maskedDst = maskedDst - meanDst
    maskedDst = maskedDst + meanSrc
    maskedDst = np.clip(maskedDst, 0, 255)
    transferredDst[maskIndices[0], maskIndices[1]] = maskedDst

    return transferredDst


def swapThread(alpha, new_colors, frame_key, frame_val, videoPath):
    start_time = time.time()
    logging.info(f"frame_key is:  {str(frame_key)}")
    global fps
    if frame_key == "fps":
        fps = frame_val.get("fps")
        return True  # the fps entry carries no frame data

    frame_count = frame_key.split(":")[0]
    frame_val = eval(frame_val)  # each frame entry is stored as a stringified dict
    vertices = frame_val.get("vertices")
    logging.info("vertices")
    fps = int(float(frame_val.get("fps")))

    # load the per-frame swap mask produced offline and binarize it
    new_mask = cv2.imread(videoPath + str(frame_count) + "_new_mask.jpg")
    new_mask = cv2.cvtColor(new_mask, cv2.COLOR_BGR2GRAY)
    new_mask = np.where(new_mask < 1, 0, 1)

    # image = base64.b64decode(image)
    # img = plt.imread(BytesIO(image), "jpg")
    # image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    time_image = time.time()
    image = cv2.imread(videoPath + str(frame_count) + ".jpg")
    if image is None or not image.data or len(image) < 1:
        return False
    global h
    global w
    [h, w, _] = image.shape
    im_size = (w, h)
    vertices = np.frombuffer(vertices, dtype=np.float32)  # np.fromstring is deprecated
    # vertices = np.frombuffer(vertices, dtype=np.float16)
    vertices = vertices.astype(np.float32).copy()
    vertices = vertices.reshape((43867, -1))  # (43867, 3)
    new_image = render_colors(vertices, prn.triangles, new_colors, h, w)  # render the reference identity onto the video face's 3D shape
    new_image = (255 * new_image).astype(np.uint8)
    # The mouth region was removed from the mask offline so the video's own mouth
    # (and teeth) stays visible; correct the rendered face's colors to match the
    # face in the video frame.
    # new_image = correct_colours(image, new_image, landmark[:,:2])
    new_image = colorTransfer(image, new_image, new_mask)
    # blend the rendered face with the original frame, weighted by alpha
    swap_image = image * (1 - new_mask[:, :, np.newaxis]) + \
                 new_image * alpha * new_mask[:, :, np.newaxis] + \
                 image * (1 - alpha) * new_mask[:, :, np.newaxis]
    # compute the center of the mask's bounding box for Poisson seamless cloning
    r = cv2.boundingRect((new_mask * 255).astype(np.uint8))
    center = ((r[0] + np.round(r[2] / 2), r[1] + np.round(r[3] / 2)))
    center = tuple(map(int, center))

    if image is None or not image.data or len(image) < 1:
        return False

    output = cv2.seamlessClone(swap_image.astype(np.uint8), image,
                               (new_mask * 255).astype(np.uint8), center, cv2.NORMAL_CLONE)
    out = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)

    if out is None or not out.data or len(out) < 1:
        return False
    logging.info(f"swap merge face cost time is:  {str(time.time() - start_time)}")
    # print(f"swap merge face cost time is:  {str(time.time() - start_time)}")
    time1 = time.time()
    ret, buf = cv2.imencode(".jpg", out)
    out_base64 = base64.b64encode(buf)
    lock.acquire()
    global imageList
    imageList[int(float(frame_count))] = out_base64
    lock.release()

    return True


def faceSwapRun(alpha, new_colors, frame_dict, videoPath, imageList):
    fps = 25
    w = 255
    h = 255

    for frame_key, frame_val in frame_dict.items():
        start_time = time.time()
        logging.info(f"frame_key is:  {str(frame_key)}")
        # global fps
        if frame_key == "fps":
            fps = frame_val.get("fps")
            continue

        frame_count = frame_key.split(":")[0]
        frame_val = eval(frame_val)  # each frame entry is stored as a stringified dict
        vertices = frame_val.get("vertices")
        logging.info("vertices")
        fps = int(float(frame_val.get("fps")))
        new_mask = cv2.imread(videoPath + str(frame_count) + "_new_mask.jpg")
        new_mask = cv2.cvtColor(new_mask, cv2.COLOR_BGR2GRAY)
        new_mask = np.where(new_mask < 1, 0, 1)
        time_image = time.time()
        image = cv2.imread(videoPath + str(frame_count) + ".jpg")
        if image is None or not image.data or len(image) < 1:
            continue

        [h, w, _] = image.shape
        im_size = (w, h)
        vertices = np.frombuffer(vertices, dtype=np.float32)  # np.fromstring is deprecated
        # vertices = np.frombuffer(vertices, dtype=np.float16)
        vertices = vertices.astype(np.float32).copy()
        vertices = vertices.reshape((43867, -1))  # (43867, 3)

        new_image = render_colors(vertices, prn.triangles, new_colors, h, w)  # render the reference identity onto the video face's 3D shape
        new_image = (255 * new_image).astype(np.uint8)
        # The mouth region was removed from the mask offline so the video's own mouth
        # (and teeth) stays visible; correct the rendered face's colors to match the
        # face in the video frame.
        # new_image = correct_colours(image, new_image, landmark[:,:2])
        new_image = colorTransfer(image, new_image, new_mask)
        print(new_image.shape)
        print(image.shape)
        # blend the rendered face with the original frame, weighted by alpha
        swap_image = image * (1 - new_mask[:, :, np.newaxis]) + \
                     new_image * alpha * new_mask[:, :, np.newaxis] + \
                     image * (1 - alpha) * new_mask[:, :, np.newaxis]

        # compute the center of the mask's bounding box for Poisson seamless cloning
        r = cv2.boundingRect((new_mask * 255).astype(np.uint8))
        center = ((r[0] + np.round(r[2] / 2), r[1] + np.round(r[3] / 2)))
        center = tuple(map(int, center))
        center = tuple(map(int, center))

        if image is None or not image.data or len(image) < 1:
            continue

        output = cv2.seamlessClone(swap_image.astype(np.uint8), image,
                                   (new_mask * 255).astype(np.uint8), center, cv2.NORMAL_CLONE)
        out = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)

        if out is None or not out.data or len(out) < 1:
            continue
        logging.info(f"swap merge face cost time is:  {str(time.time() - start_time)}")
        # print(f"swap merge face cost time is:  {str(time.time() - start_time)}")
        time1 = time.time()
        ret, buf = cv2.imencode(".jpg", out)
        out_base64 = base64.b64encode(buf)
        print(f"encode base64 cost time is:  {str(time.time() - time1)}")

        print("frame_count is : ", frame_count)
        # global imageList
        imageList[int(float(frame_count))] = out_base64

    return imageList, fps, w, h


def get_redis(video_key, redisDb, i, alpha, new_colors, videoPath):
    logging.info("key is:  " + video_key)
    frame_dict = eval(redisDb.get(video_key))
    global frame_count_all
    frame_count_all += len(frame_dict)
    for frame_key, frame_val in frame_dict.items():
        swapThread(alpha, new_colors, frame_key, frame_val, videoPath)

    return True


def get_redis1(video_key, redisDb, i, alpha, new_colors, videoPath):
    logging.info("key is:  " + video_key)
    frame_dict = eval(redisDb.get(video_key))
    global frame_count_all
    frame_count_all += len(frame_dict)
    global frame_dict_list
    # frame_dict_list.append(frame_dict)
    frame_dict_list.update(frame_dict)

    return True


def faceSwap(ref_image, video_id, prn, videoPath):
    try:
        begin_time = time.time()
        # blending weight for the rendered face
        alpha = 0.8
        # read the reference image and extract its color space (the identity)
        # ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB)
        ref_image = face_aligner.aligner(ref_image)
        h, w, _ = ref_image.shape
        boxes, _ = face_detector(ref_image)
        ref_pos = prn.process(ref_image, image_info=boxes[0])
        # ref_pos = prn.process(ref_image)
        logging.info("faceDetector and prn.process  cost time:  " + str(time.time() - begin_time))

        ref_image = ref_image / 255.
        ref_texture = cv2.remap(ref_image, ref_pos[:, :, :2].astype(np.float32), None,
                                interpolation=cv2.INTER_NEAREST,
                                borderMode=cv2.BORDER_CONSTANT, borderValue=(0))
        # sample the per-vertex colors from the reconstructed ref_texture
        new_colors = prn.get_colors_from_texture(ref_texture)
        logging.info("to remap get colors  cost time:  " + str(time.time() - begin_time))

        # fetch the per-frame vertex shards from Redis and swap each frame
        redis_time = time.time()

        global all_task
        global frame_count_all
        frame_count_all = 0
        for i in range(120):
            video_key = video_id + "-" + str(i + 1)
            if redisDb.exists(video_key):
                # redis_thread = executor.submit(get_redis1, video_key, redisDb, str(i + 1), alpha, new_colors, videoPath)
                redis_thread = executor.submit(get_redis, video_key, redisDb, str(i + 1), alpha, new_colors, videoPath)
                all_task.append(redis_thread)

        # executor.shutdown(wait=True)
        # wait(all_task, return_when=ALL_COMPLETED)
        for future in as_completed(all_task):
            data = future.result()
            logging.info(f"in main: get page {str(data)}s success")

        # frame_dict = eval(redisDb.get(video_id))
        logging.info("get redis val cost time:  " + str(time.time() - redis_time))
        print("get redis val cost time:  " + str(time.time() - redis_time))

        logging.info("frame_dict_list len is:  " + str(len(frame_dict_list)))
        logging.info("threads swap face cost time:  " + str(time.time() - begin_time))

        # Alternative: collect all frames first via get_redis1, then swap each
        # frame on its own thread:
        # for frame_key, frame_val in frame_dict_list.items():
        #     thread = Thread(target=swapThread, args=(alpha, new_colors, frame_key, frame_val, videoPath))
        #     swap_threads.append(thread)
        #     thread.start()
        # for t in swap_threads:
        #     t.join()
        return True
    except Exception as ex:
        logging.exception(ex)
        return False


@app.route('/ai/v1/FaceSwap', methods=['POST'])
def faceSwapMethod():
    try:
        start_time = time.time()
        resParm = request.data
        # parse the JSON request body
        resParm = str(resParm, encoding="utf-8")
        resParm = json.loads(resParm)  # was eval(); json.loads is safer for a JSON API

        requestId = resParm.get('requestId')
        # service authentication
        token = resParm.get('token')
        if not token:
            res = {'code': 3, 'msg': 'token fail'}
            logging.error("code: 3 msg:  token fail ")
            return json.dumps(res)
        videoId = resParm.get("videoName")
        if videoId is None or videoId.strip() == '':
            res = {'code': 7, 'msg': 'videoName is null'}
            logging.error("code: 7 msg:  videoName is null")
            return json.dumps(res)

        # decode the base64-encoded image
        modelImg_base64 = resParm.get("inputImage")
        if not modelImg_base64:
            res = {'code': 4, 'msg': ' picture param invalid'}
            logging.error("code: 4  msg:  picture param invalid")
            return json.dumps(res)
        modelImg_data_1 = None
        if is_has_glass(modelImg_base64):
            # glasses detected: decode the image and remove the glasses with the
            # Pluralistic FaceEditor before swapping
            modelImg = base64.b64decode(modelImg_base64)
            modelImg_data = np.frombuffer(modelImg, np.uint8)  # np.fromstring is deprecated
            modelImg_data_1 = cv2.imdecode(modelImg_data, cv2.IMREAD_COLOR)

            image = cv2.cvtColor(modelImg_data_1, cv2.COLOR_BGR2RGB)
            res = editor.removeglasses(image)
            modelImg_data_1 = res[0]
            img = cv2.cvtColor(modelImg_data_1, cv2.COLOR_BGR2RGB)
            cv2.imwrite("glass_img.jpg", img)  # debug output
        else:

            modelImg = base64.b64decode(modelImg_base64)
            # recv_time = time.time()
            # logging.info(f"recv image cost time:  {str(recv_time - start_time)}")
            modelImg_data = np.frombuffer(modelImg, np.uint8)  # np.fromstring is deprecated
            modelImg_data_1 = cv2.imdecode(modelImg_data, cv2.IMREAD_COLOR)
        # cv2.imwrite("modelImg.jpg", modelImg_data_1)
        # validate the image dimensions
        if modelImg_data_1.shape[0] > config.size or modelImg_data_1.shape[1] > config.size:
            res = {'code': 5, 'msg': ' picture size invalid'}
            logging.error("code: 5 msg: picture size invalid")
            return json.dumps(res)
        logging.info(f"modelImg_data_1  shape:  {str(modelImg_data_1.shape)}   size:  {str(modelImg_data_1.size)}")

        time_predict = time.time()
        # cv2.imwrite("upload_ref.jpg", modelImg_data_1)
        modelImg_data_1 = cv2.cvtColor(modelImg_data_1, cv2.COLOR_BGR2RGB)
        swapRes = gen_swap_face(modelImg_data_1, videoId, prn)

        logging.info(f"face swap cost Time is: {str(time.time() - time_predict)} ")
        for t in swap_threads:
            t.join()

        timeUsed = time.time() - start_time
        data = {'requestId': requestId, 'faceSwapRes': str(swapRes), 'timeUsed': str(timeUsed)}
        res = {'code': 0, 'msg': 'success', 'data': data}
        logging.info(f"code:0  msg:success  face swap cost Time is: {str(timeUsed)} ")
        return json.dumps(res)
    except Exception as e:
        logging.exception(e)
        res = {'code': 6, 'msg': 'request exception'}
        return json.dumps(res)


def gen_swap_face(modelImg_data_1, videoId, prn):
    try:
        videoName = os.path.basename(videoId)
        videoName = videoName.split(".")[0]
        refImgMd = hashlib.md5(modelImg_data_1).hexdigest()
        videoPath = './img_video/' + videoName + "/"

        save_res_path = './img_video/' + videoName + "/" + refImgMd + "/"
        if not os.path.exists(save_res_path):
            os.makedirs(save_res_path)

        time_predict = time.time()
        # imageList, fps, w, h = faceSwap(modelImg_data_1, videoId, prn, videoPath)
        swapRes = faceSwap(modelImg_data_1, videoId, prn, videoPath)
        if not swapRes:
            return False

        print(f"face swap Method cost Time is: {str(time.time() - time_predict)} ")

        global imageList
        global fps
        global w
        global h
        global frame_count_all

        im_size = (w, h)
        out = None
        logging.info(f"imageList len is: {str(len(imageList))}")
        if len(imageList) < 1:
            return False

        start_time = time.time()

        # for image in imageList:
        for i in range(frame_count_all):
            image = imageList[i]

            if image is None or len(image) < 1:
                continue

            image = base64.b64decode(image)
            img = plt.imread(BytesIO(image), "jpg")  # decoded as RGB
            image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # swap channels to BGR for VideoWriter
            if out is None:
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                out = cv2.VideoWriter(save_res_path + videoName + "-" + refImgMd + ".mp4", fourcc, fps, im_size, True)
            out.write(image)
            # logging.info(f"imageList len is: {str(len(imageList))}")
            # logging.info(f"img_size is: {str(im_size)}")
            # print(str(i) + "index  frame_count_all  ", frame_count_all)
        logging.info("image List to merge face video cost:  " + str(time.time() - start_time))
        return True
    except Exception as x:
        logging.exception(x)
        return False


def save_video_face(videoName):
    cap = cv2.VideoCapture(videoName)
    fps = cap.get(cv2.CAP_PROP_FPS)
    im_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    out = None
    frameId = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frameId < 350:
            # skip the first 350 frames
            frameId += 1
            continue
        if out is None:
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            out = cv2.VideoWriter("test_swap1_200f.mp4", fourcc, fps, im_size, True)
        out.write(frame)
        frameId += 1
    cap.release()
    if out is not None:
        out.release()  # finalize the output file


if __name__ == "__main__":
    logging.info('Starting the server...')
    server.listen(("0.0.0.0", 9775))
    server.run(app)
    # app.run(host='0.0.0.0', port=18885, threaded=True)

 
