人脸识别:史上最详细人脸识别adaface讲解-ckpt转onnx模型--第三节

4 篇文章 1 订阅
4 篇文章 2 订阅

这章节我会讲解的是我在工作上的项目,人脸识别adaface,以下的讲解为个人的看法,若有地方说错的我会第一时间纠正,如果觉得博主讲解的还可以的话点个赞,就是对我最大的鼓励~

上一章节我们讲到了模型的训练与测试,相信大家已经训练好了模型,当然我们训练出来的模型是需要部署到实际项目上面去的,这个就需要把ckpt模型转成onnx模型,由于该作者没有提供转onnx模型的代码,这里我把自己编写的转onnx代码分享给大家。

当然,我们先上代码,对代码进行讲解。
1、首先我们需要填写的参数有4个:
训练出来的模型路径(–ckpt_path)
输出的模型名字(–onnx_name)
主干网络名字(ir_18)
推理图片路径(–image)
2、我们从main函数去看,opencv读取一张图片,并且对图片进行一些格式的转换,最后变成一个4维的tensor。
3、加载训练的主干网络后进行onnx的模型转换。
4、onnx模型出来之后,当然我们需要推理一张图片,推理结果需要跟我们的ckpt模型的输出结果是一致的,这样我们才说onnx模型是对齐的,才能去转其他的模型。

import argparse
import cv2
import numpy as np
import onnxruntime
import torchvision.transforms as transforms
import torch
import net

parser = argparse.ArgumentParser(description='onnx inference')
parser.add_argument('--ckpt_path', default=r'./experiments/default_08-01_1/epoch=24-step=142174.ckpt', type=str, required=True, help='')
parser.add_argument('--onnx_name', default='ada_Face_142174', type=str, required=True, help='')
parser.add_argument('--model_name', default=None, type=str, required=True, help='')
parser.add_argument('--onnx_path', default=None, type=str, required=True, help='')
parser.add_argument('--image', default=None, type=str, required=True, help='')
args = parser.parse_args()

adaface_models = {
    'ir_18':r".\experiments\default_08-01_1\epoch=24-step=142174.ckpt".replace('\\','/'),
}
def load_pretrained_model(architecture='ir_18'):
    # load model and pretrained statedict
    assert architecture in adaface_models.keys()
    model = net.build_model(architecture)
    statedict = torch.load(adaface_models[architecture])['state_dict']
    model_statedict = {key[6:]:val for key, val in statedict.items() if key.startswith('model.')}
    model.load_state_dict(model_statedict)
    model.eval()
    return model

def l2_norm(x):
    """ l2 normalize
    """
    output = x / np.linalg.norm(x)
    return output


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def get_test_transform():
    test_transform = transforms.Compose([    
            transforms.ToTensor(),    
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    return test_transform


def main():
    # img = cv2.imread(args.image)
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # img = get_test_transform()(img)
    # img = img.unsqueeze_(0)
    # print(img.shape)
    img1 = cv2.imread(args.image)
    img1 = np.expand_dims(img1, axis=0)
    img1 = img1.astype(np.float32)
    img1 = img1.transpose(0, 3, 1, 2)
    img1 = torch.from_numpy(img1)
    model = load_pretrained_model('ir_18')
    model.eval()
    torch_out = model(img1)
    torch.onnx.export(model,
                      img1,
                      "%s.onnx" % args.onnx_name,
                      export_params=True,
                      opset_version=11,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'])
    print("finished")
    onnx_path = args.onnx_path
    session = onnxruntime.InferenceSession(onnx_path)
    inputs = {session.get_inputs()[0].name: to_numpy(img1)}
    outs = session.run(None, inputs)[0]
    # print(outs.shape)
    outs = l2_norm(outs).squeeze()
    num = torch_out[0].detach().numpy() * outs
    de1 = np.linalg.norm(torch_out[0].detach().numpy()) * np.linalg.norm(outs)
    print(np.sum(num) / de1)



if __name__ == '__main__':
    main()

现在,我们的onnx模型推理结果与ckpt模型的推理结果需要做一个余弦相似度的对比,如果结果无限接近于1,那证明两边的模型结果对齐。

这里我们可以看到torch_out和outs的值,感官上看是对齐的,但是我们还需要进行余弦相似度的对比。
结果输出为1,证明两边的模型是对齐的,这样我们就可以往下面的模型去转换了。
在这里插入图片描述
在这里插入图片描述

到这里我们的ckpt模型转onnx模型已经告一段落啦,相信大家都已经完成了这不操作,接下来我会分享我自己对该adaface网络的一个详细解答。希望大家能多多支持。如果觉得博主讲解的可以的话,点个赞就是我最大的动力,谢谢~
  • 6
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值