风格迁移开发记录(DCT-Net)

1.DCT-Net部署

阿里旗下的 ModelScope 社区提供了丰富的开源风格迁移算法模型。
image.png
DCT-Net GitHub 链接:https://github.com/menyifang/DCT-Net

git clone https://github.com/menyifang/DCT-Net.git
cd DCT-Net

运行 python run_sdk.py 下载不同风格的模型。

如下图所示,每个文件夹代表一种风格,各包含 cartoon_bg.pb 和 cartoon_h.pb 两个模型(bg 是全图风格模型,h 是脸部风格模型):
image.png

模型转换

export_model.py

模型转换方式:不能将 pb 模型整体转换,需要截取中间节点进行转换;部分前后处理节点 RKNN 或 NCNN 不支持,需要放在 CPU 上处理。转换路线如下:
pb->tflite->rknn
pb->onnx->ncnn

"""
@File   : export_model.py
@Author : 
@Date   : 2023/12/13
@Desc   : 
"""
import os
import shutil
import tensorflow as tf
import cv2
import tf2onnx
import onnx
import time
import onnxruntime
import subprocess
import numpy as np

# python -m tf2onnx.convert --graphdef .\damo\cv_unet_person-image-cartoon_compound-models\cartoon_bg.pb --output .\damo\cv_unet_person-image-cartoon_compound-models\cartoon_bg.onnx --
# inputs input_image:0 --outputs output_image:0 
def convert_pb2tflite(model_path, input_shape, model_dir, type_name):
    """Convert a frozen TensorFlow graph (``<model_path>.pb``) to TFLite.

    Only the sub-graph between the ``strided_slice_1`` (input) and
    ``strided_slice_4`` (output) nodes is converted.

    Args:
        model_path: Path to the frozen graph WITHOUT the ``.pb`` extension.
        input_shape: Static shape fed to ``strided_slice_1``,
            e.g. ``[1, 1280, 720, 3]``.
        model_dir: Output directory for the ``.tflite`` file (must exist).
        type_name: Style name used as the output file-name prefix.

    The model is written to ``<model_dir>/<type_name>_<basename>.tflite``,
    and the converted model's input/output tensor details are printed.
    """
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        graph_def_file=model_path + '.pb',
        input_arrays=["strided_slice_1"],
        output_arrays=["strided_slice_4"],
        input_shapes={'strided_slice_1': input_shape},
    )
    # Optional float16 weight quantization:
    # converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # converter.target_spec.supported_types = [tf.float16]
    tflite_model = converter.convert()

    # Callers build model_path with os.path.join, which uses '\\' on Windows;
    # split('/') would then keep the whole path, basename does not.
    bgh = os.path.basename(model_path)
    save_path = os.path.join(model_dir, type_name + '_' + bgh + '.tflite')
    print('===> ', save_path, bgh)
    # The context manager guarantees the file is flushed and closed.
    with open(save_path, "wb") as f:
        f.write(tflite_model)

    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    print('===> input: ', input_details)
    output_details = interpreter.get_output_details()
    print('===> output: ', output_details)

def pb2tflite_2():
    """Export the background model of every ``cv_unet*`` style to TFLite.

    Scans ./damo/ for style directories, converts each one's ``cartoon_bg``
    graph at a fixed input shape, and writes the results to
    ``./damo/tflite_model<H>x<W>`` (created if missing).
    """
    pb_dir = './damo/'
    bg_model, h_model = 'cartoon_bg', 'cartoon_h'
    # Other background shapes tried: [1, 720, 720, 3], [1, 1920, 1080, 3],
    # [1, 2560, 1440, 3]
    bg_input_shape = [1, 1280, 720, 3]
    head_input_shape = [1, 288, 288, 3]  # only for the disabled head export below
    tflite_dir = './damo/tflite_model{}x{}'.format(bg_input_shape[1], bg_input_shape[2])
    if not os.path.exists(tflite_dir):
        os.makedirs(tflite_dir)

    style_dirs = [name for name in os.listdir(pb_dir) if name.startswith('cv_unet')]
    for style_dir in style_dirs:
        model_dir = os.path.join(pb_dir, style_dir)
        bg_path = os.path.join(model_dir, bg_model)
        print('============', style_dir)
        # e.g. 'cv_unet_person-image-cartoon_compound-models' -> 'cartoon'
        type_name = style_dir.split('-')[-2].split('_')[0]
        print('===> ', style_dir, type_name, bg_path)
        convert_pb2tflite(bg_path, bg_input_shape, tflite_dir, type_name)
        # Head-model export is disabled:
        # convert_pb2tflite(os.path.join(model_dir, h_model), head_input_shape, tflite_dir, type_name)

def pb2onnx(bg_path, bg_input_shape, onnx_dir, type_name):
    """Convert ``<bg_path>.pb`` to ONNX with tf2onnx, then simplify it with onnxsim.

    The graph is cut at ``strided_slice_1:0`` (input) and ``add_1:0``
    (output), and both ends are exposed as NCHW.

    Args:
        bg_path: Path to the frozen graph WITHOUT the ``.pb`` extension.
        bg_input_shape: NHWC shape ``[n, h, w, c]``; onnxsim receives it
            reordered as ``n,c,h,w``.
        onnx_dir: Output directory (must exist).
        type_name: Style name used as the output file-name prefix.

    Writes ``<onnx_dir>/<type_name>_<basename>.onnx`` and the simplified
    ``<type_name>_<basename>-sim.onnx`` beside it.

    Returns:
        Path of the simplified model, or ``None`` if either step failed.
        On failure, the tool's error output is printed.
    """
    import sys  # local import: the file's import block does not include sys

    pb_path = bg_path + '.pb'
    # basename also splits on '\\' on Windows, where os.path.join uses it.
    onnx_name = type_name + '_' + os.path.basename(bg_path) + '.onnx'
    onnx_path = os.path.join(onnx_dir, onnx_name)

    # Argument lists (no shell) keep paths containing spaces intact, and
    # sys.executable runs the tools with the interpreter running this script.
    convert_cmd = [
        sys.executable, '-m', 'tf2onnx.convert',
        '--graphdef', pb_path,
        '--output', onnx_path,
        '--inputs', 'strided_slice_1:0',
        '--outputs', 'add_1:0',
        '--inputs-as-nchw', 'strided_slice_1:0',
        '--outputs-as-nchw', 'add_1:0',
    ]
    result = subprocess.run(convert_cmd, capture_output=True,
                            encoding='utf-8', errors='replace')
    print(result.stdout)
    # run() waits for the process to exit, so no sleep is needed here.
    # Simplifying a model that was never written would only fail again.
    if result.returncode != 0 or not os.path.exists(onnx_path):
        print(f'tf2onnx failed for {pb_path}:\n{result.stderr}')
        return None

    onnx_sim_path = os.path.splitext(onnx_path)[0] + '-sim.onnx'
    n, h, w, c = bg_input_shape[0], bg_input_shape[1], bg_input_shape[2], bg_input_shape[3]
    print(bg_input_shape, onnx_path, onnx_sim_path)
    sim_cmd = [
        sys.executable, '-m', 'onnxsim', onnx_path, onnx_sim_path,
        '--overwrite-input-shape', f'{n},{c},{h},{w}',
    ]
    result = subprocess.run(sim_cmd, capture_output=True,
                            encoding='utf-8', errors='replace')
    print(result.stdout)
    if result.returncode != 0:
        print(f'onnxsim failed for {onnx_path}:\n{result.stderr}')
        return None
    return onnx_sim_path

def tfpb2onnx():
    """Convert both models of every ``cv_unet*`` style under ./damo/ to ONNX.

    For each style directory, ``cartoon_bg`` is converted at the background
    shape and ``cartoon_h`` at 288x288. Results go to
    ``./damo/onnx_model<H>x<W>``, named after the background shape and
    created if missing.

    Manual equivalent of one conversion:
        python -m tf2onnx.convert --graphdef damo/cv_unet_person-image-cartoon_compound-models/cartoon_bg.pb
            --output damo/cv_unet_person-image-cartoon_compound-models/cartoon_bg.onnx
            --inputs strided_slice_1:0 --outputs add_1:0 --inputs-as-nchw strided_slice_1:0
        python -m onnxsim cartoon_bg.onnx cartoon_bg-sim.onnx --overwrite-input-shape 1,3,1024,1024
    """
    pb_dir = './damo/'
    # Other background shapes tried: [1, 1024, 1024, 3], [1, 1920, 1080, 3]
    bg_input_shape = [1, 2560, 2560, 3]
    head_input_shape = [1, 288, 288, 3]
    onnx_dir = './damo/onnx_model{}x{}'.format(bg_input_shape[1], bg_input_shape[2])
    if not os.path.exists(onnx_dir):
        os.makedirs(onnx_dir)

    for style_dir in os.listdir(pb_dir):
        if not style_dir.startswith('cv_unet'):
            continue
        model_dir = os.path.join(pb_dir, style_dir)
        bg_path, h_path = (os.path.join(model_dir, name) for name in ('cartoon_bg', 'cartoon_h'))
        print(f'============>bg_path: {bg_path}, h_path: {h_path}')
        type_name = style_dir.split('-')[-2].split('_')[0]
        print(f'type_name: {type_name}, i: {style_dir}')
        for path, shape in ((bg_path, bg_input_shape), (h_path, head_input_shape)):
            pb2onnx(path, shape, onnx_dir, type_name)

def onnx2ncnn(onnx_dir='./damo/onnx_model2560x2560',
              tool_path='./ncnn-20231027-ubuntu-2204/bin/onnx2ncnn'):
    """Convert the ONNX models in *onnx_dir* to NCNN ``.param``/``.bin`` pairs.

    Files whose name contains ``h-sim`` (the simplified head models) are
    skipped. Every other ``.onnx`` file, including non-simplified ones, is
    converted. Outputs are written beside the input as ``<name>.param`` and
    ``<name>.bin``.

    Args:
        onnx_dir: Directory to scan.
        tool_path: Path to the ncnn ``onnx2ncnn`` executable.

    Returns:
        List of ``.onnx`` paths whose conversion exited with status 0.
        Failures are printed (with the tool's stderr) and skipped. If the
        executable cannot be launched, the scan stops early.
    """
    converted = []
    for name in sorted(os.listdir(onnx_dir)):
        if not name.endswith('.onnx') or 'h-sim' in name:
            continue
        onnx_path = os.path.join(onnx_dir, name)
        onnx_name = os.path.splitext(onnx_path)[0]
        ncnn_param = onnx_name + '.param'
        ncnn_bin = onnx_name + '.bin'
        print(f'onnx_name: {onnx_name}, ncnn_param: {ncnn_param}, ncnn_bin {ncnn_bin}')
        try:
            # Argument list (no shell) keeps paths with spaces intact.
            result = subprocess.run([tool_path, onnx_path, ncnn_param, ncnn_bin],
                                    capture_output=True, encoding='utf-8', errors='replace')
        except OSError as err:
            # Every remaining file would fail the same way.
            print(f'cannot run {tool_path}: {err}')
            break
        print(result.stdout)
        # Diagnostics from the tool go to stderr; show them instead of discarding.
        if result.stderr:
            print(result.stderr)
        if result.returncode == 0:
            converted.append(onnx_path)
        else:
            print(f'onnx2ncnn failed ({result.returncode}) for {onnx_path}')
    return converted

def ncnn_optimize(onnx_dir='./damo/onnx_model1024x1024',
                  tool_path='./ncnn-20231027-ubuntu-2204/bin/ncnnoptimize',
                  flag=1):
    """Run ``ncnnoptimize`` on every NCNN model in *onnx_dir*.

    For each ``<name>.param`` with a matching ``<name>.bin``, this writes
    ``<name>-opt.param`` and ``<name>-opt.bin`` beside them. Files that are
    already ``-opt`` outputs are skipped, so re-running does not create
    ``-opt-opt`` models.

    Args:
        onnx_dir: Directory to scan.
        tool_path: Path to the ncnn ``ncnnoptimize`` executable.
        flag: Final ncnnoptimize argument. Presumably it selects the weight
            storage type (1 ~ fp16); confirm against the ncnnoptimize usage
            of the ncnn version in use.

    Returns:
        List of input ``.param`` paths that were optimized with exit status 0.
        Failures are printed and skipped. If the executable cannot be
        launched, the scan stops early.
    """
    optimized = []
    for name in sorted(os.listdir(onnx_dir)):
        if not name.endswith('.param') or name.endswith('-opt.param'):
            continue
        param_path = os.path.join(onnx_dir, name)
        # splitext only touches the extension; str.replace('.param', ...)
        # would also rewrite '.param' appearing elsewhere in the path.
        stem = os.path.splitext(param_path)[0]
        bin_path = stem + '.bin'
        opt_param_path = stem + '-opt.param'
        opt_bin_path = stem + '-opt.bin'
        print(f'param_path: {param_path}, bin_path: {bin_path}, opt_param_path {opt_param_path}, opt_bin_path {opt_bin_path}')
        if not os.path.exists(bin_path):
            print(f'missing weights for {param_path}: {bin_path}')
            continue
        cmd = [tool_path, param_path, bin_path, opt_param_path, opt_bin_path, str(flag)]
        try:
            result = subprocess.run(cmd, capture_output=True,
                                    encoding='utf-8', errors='replace')
        except OSError as err:
            # Every remaining file would fail the same way.
            print(f'cannot run {tool_path}: {err}')
            break
        print(result.stdout)
        if result.stderr:
            print(result.stderr)
        if result.returncode == 0:
            optimized.append(param_path)
        else:
            print(f'ncnnoptimize failed ({result.returncode}) for {param_path}')
    return optimized



if __name__ == '__main__':
    # Only the TFLite export is enabled. Add any of these to the tuple to run them:
    #   tfpb2onnx      - .pb -> .onnx (+ onnxsim)
    #   onnx2ncnn      - .onnx -> .param/.bin
    #   ncnn_optimize  - .param/.bin -> -opt.param/-opt.bin
    for stage in (pb2tflite_2,):
        stage()

tflite2rknn.py

import os
import time
import shutil
import numpy as np
import cv2
from rknn.api import RKNN


def show_outputs(outputs, labels_path='./labels.txt', top_k=5):
    """Print and return the top-k classification report for ``outputs[0][0]``.

    Args:
        outputs: Model outputs; ``outputs[0][0]`` is used as the 1-D score vector.
        labels_path: Text file with one label per line. The text after the
            last ':' is used as the class name.
        top_k: Number of entries to list. It is capped at the number of
            scores, so short score vectors no longer raise IndexError.

    Returns:
        The report string that was printed. Non-positive scores are shown
        as ``[  -1]: 0.0``.
    """
    scores = outputs[0][0]
    # Stable sort, highest score first; ties keep index order.
    order = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
    # The context manager closes the labels file (it was previously leaked).
    with open(labels_path, 'r', encoding='utf-8') as fp:
        labels = fp.readlines()
    report = f'mobilenet_v1\n-----TOP {top_k}-----\n'
    for idx in order[:top_k]:
        value = scores[idx]
        if value > 0:
            report += '[{:>4d}] score:{:.6f} class:"{}"\n'.format(
                idx, value, labels[idx].strip().split(':')[-1])
        else:
            report += '[  -1]: 0.0\n'
    report = report.strip()
    print(report)
    return report

def dequantize(outputs, scale, zp):
    """Map the first quantized output back to real values, in place.

    Computes ``(q - zp) * scale`` for ``outputs[0]``, stores the result back
    into the same list, and returns that (mutated) list. Other entries are
    left untouched.
    """
    quantized = outputs[0]
    real_values = (quantized - zp) * scale
    outputs[0] = real_values
    return outputs

def letterbox(im, new_shape=(640, 640), color=(0, 0, 0)):
    """Resize *im* to fit inside *new_shape* keeping its aspect ratio, then pad.

    Args:
        im: Image array whose first two dimensions are (height, width).
        new_shape: Target (height, width), or a single int for a square.
        color: Border fill value.

    Returns:
        ``(padded_image, (r, r), (dw, dh))``, where ``r`` is the scale factor
        and ``dw``/``dh`` are the per-side paddings (possibly fractional).
    """
    src_h, src_w = im.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    scale = min(dst_h / src_h, dst_w / src_w)
    resized_w = int(round(src_w * scale))
    resized_h = int(round(src_h * scale))

    # Split the leftover space evenly between the two sides.
    half_pad_w = (dst_w - resized_w) / 2
    half_pad_h = (dst_h - resized_h) / 2

    if (src_w, src_h) != (resized_w, resized_h):
        im = cv2.resize(im, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR)

    # The +/-0.1 offsets give an odd leftover pixel to the bottom/right side.
    top = int(round(half_pad_h - 0.1))
    bottom = int(round(half_pad_h + 0.1))
    left = int(round(half_pad_w - 0.1))
    right = int(round(half_pad_w + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return im, (scale, scale), (half_pad_w, half_pad_h)

def post_process(rknn_result):
    """Convert a model output (presumably tanh-style, in [-1, 1]) to uint8 pixels.

    Values are clipped to +/-0.999999, mapped linearly via ``(x + 1) * 127.5``,
    and truncated to uint8, so every pixel lands in 0..254.
    """
    bounded = rknn_result.clip(-0.999999, 0.999999)
    return ((bounded + 1) * 127.5).astype('uint8')

def export_rknn(tflite_model_path, QUANTIZE_ON, DATASET, img_path='./16.png',
                img_size=(288, 288), target_platform='rk3588'):
    """Convert a TFLite model to RKNN, export it, and run one test inference.

    Steps: config -> load_tflite -> build -> export (``.tflite`` replaced by
    ``.rknn`` in the path) -> init_runtime -> inference on *img_path*. The
    post-processed result is saved as ``.jpg`` next to the model.

    Args:
        tflite_model_path: Path of the ``.tflite`` model.
        QUANTIZE_ON: Passed to ``rknn.build(do_quantization=...)``.
        DATASET: Passed to ``rknn.build(dataset=...)``; presumably only used
            when quantizing.
        img_path: Test image for the sanity-check inference.
        img_size: (width, height) the test image is resized to. It must match
            the model input; the default is 288x288.
        target_platform: RKNN target SoC passed to ``rknn.config``.

    If a step fails, the process exits with that step's return code (or 1 if
    the test image cannot be read). The RKNN object is released in every case.
    """
    import sys  # local import: this script's import block does not include sys

    rknn = RKNN(verbose=True)
    try:
        print('--> Config model')
        # rknn.config(mean_values=[128, 128, 128], std_values=[128, 128, 128], target_platform='rk3566')
        rknn.config(target_platform=target_platform)
        print('done')

        print('--> Loading model')
        ret = rknn.load_tflite(model=tflite_model_path)
        if ret != 0:
            print('Load model failed!')
            sys.exit(ret)
        print('done')

        print('--> Building model')
        ret = rknn.build(do_quantization=QUANTIZE_ON, dataset=DATASET)
        if ret != 0:
            print('Build model failed!')
            sys.exit(ret)
        print('done')

        print('--> Export rknn model')
        ret = rknn.export_rknn(tflite_model_path.replace('.tflite', '.rknn'))
        if ret != 0:
            print('Export rknn model failed!')
            sys.exit(ret)
        print('done')

        print('--> Init runtime environment')
        ret = rknn.init_runtime()
        if ret != 0:
            print('Init runtime environment failed!')
            sys.exit(ret)
        print('done')

        # cv2.imread returns None instead of raising for unreadable files.
        img = cv2.imread(img_path)
        if img is None:
            print(f'Read image failed: {img_path}')
            sys.exit(1)
        # NOTE(review): the image is fed as raw BGR uint8 with no color
        # conversion or normalization; confirm this matches the model input.
        img = cv2.resize(img, img_size)
        img = np.expand_dims(img, 0)
        print(f'===> input shape: {img.shape}')

        print('--> Running model')
        outputs = rknn.inference(inputs=[img], data_format=['nhwc'])
        print(f'===> output shape: {outputs[0].shape}')
        cartoon_img = post_process(outputs[0])
        # cv2.imwrite needs an HxWxC image; drop a leading batch axis if the
        # runtime returned one (NOTE(review): confirm the output layout).
        if cartoon_img.ndim == 4:
            cartoon_img = cartoon_img[0]
        # Name the image after this function's argument, not a caller's global.
        cv2.imwrite(tflite_model_path.replace('.tflite', '.jpg'), cartoon_img)
        print('done')
    finally:
        rknn.release()

if __name__ == '__main__':
    model_dir = './StyleTransfer/DCT-Net-main/damo/tflite_head'
    QUANTIZE_ON = False
    DATASET = './dataset.txt'
    # Convert every .tflite model in model_dir, in os.listdir order.
    # Keep the module-level name ``model_path``; check export_rknn before renaming it.
    tflite_files = [name for name in os.listdir(model_dir) if name.endswith('.tflite')]
    for name in tflite_files:
        model_path = os.path.join(model_dir, name)
        print(f'model path: {model_path}')
        export_rknn(model_path, QUANTIZE_ON, DATASET)
    

RKNN和NCNN推理代码

GitHub

  • 6
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值