record----pytorch转libtorch详解

pytorch转libtorch使用记录

前言

记录转换内容:pytorch->libtorch
环境要求:torch->libtorch转换,需要的环境pytorch版本要和c++工程中使用的libtorch第三方库版本一致,即转换版本要和使用版本一致,但是用于训练.pth模型的pytorch版本可以不一致,例如训练的时候选择pytorch1.4,转换使用pytorch1.0,c++中使用的是libtorch1.0.
注意:训练用高版本,转换用低版本,可能转换不出,因为网络结构对pytorch版本的依赖。libtorch1.0以上版本在c++中肯能造成内存泄露,高版本要谨慎使用。

code

以下是在pytorch1.0环境下

import numpy as np
from pynvml import *     #pip install nvidia-ml-py3,显存监控
import torch.onnx
import torch.backends.cudnn as cudnn
from models_class.pose_resnet import *
from models_class.slowfastnet import *
import torchvision.transforms as transforms
from PIL import Image

from models_class.resnet_chejian import *  #导入需要转换的模型的网络结构

# 将torch模型转为libtorch模型
def gen_torch_to_libtorch(model_path, save_path):
    # 设置cuda
    if torch.cuda.is_available():
    	print("ok cuda")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cudnn.benchmark = True
    #加载网络结构
    model = get_pose_net() #根据自己的网络进行加载,此处需要修改
    
    #×××××第一种方式加载网络权重××××#
    #此方式可以加载训练中的各种信息,但是不包括网络结构,训练保存模型时需要设置相关参数#
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])  # 加载模型参数,转换时必须需要
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化参数,转换时可以不需要
    epoch = checkpoint['epoch']  # 加载epoch,可以用于更新学习率等,转换时可以不需要
    #训练保存模型时,需要进行以下操作
    # 保存模型的状态,可以设置一些参数,后续训练可以使用,例如中断后可以继续训练
    # state = {'epoch': epoch + 1,  # 保存的当前轮数
    #          'state_dict': model.state_dict(),  # 训练好的参数
    #          'optimizer': optimizer.state_dict(),  # 优化器参数,为了后续的resume
    #          'best_pred': best_pred  # 当前最好的精度
    #             , ...., ...}
    # torch.save(state, ‘model_path’)

	#×××××第二种方式加载网络权重××××#
	#此方式仅加载模型权重参数,不加载其他任何信息#
	model.load_state_dict(torch.load(model_path))
	#训练保存模型时,需要进行以下操作
	#torch.save(model.module.state_dict(),model_path)
	
	#×××××第三种方式加载网络权重××××#
	#此方式不仅加载模型权重参数,而且加载了网络结构,两者放在了一起#
	model = torch.load(model_path)  #进行此步则不需要进行加载网络结构,一部进行完成
	#训练保存模型时,需要进行以下操作
	#torch.save(model,model_path)
	
	#将model导入到gpu中
	model.to(device)
    model.cuda()
    #进行eval操作,必须要进行,而且是在此处进行
    model.eval() #必须要在此处进行eval()
    
    #根据网络输入设置一个输入tensor
    example = torch.rand(1, 3, 224, 224).cuda()
    #进行转换保存
    traced_script_module = torch.jit.trace(model, example)
    traced_script_module.save(save_path)

# 加载torch模型进行测试示例
def load_torch_model_test(img_path,model_path):
    # 初始化模型
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cudnn.benchmark = True
    model = get_pose_net()
    model = model.cuda()
    model.to(device)

    # 加载模型权重
    model.load_state_dict(torch.load(model_path)['state_dict'])
    # 进行torch模型测试
    test_trans = transforms.Compose([transforms.Resize((256, 256)),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    img = Image.open(img_path).convert('RGB')
    img = test_trans(img)
    # img = Image.fromarray(np.expand_dims(img, 0))  # 扩张0维度
    imgblob = img.unsqueeze(0)  # tensor禁止进行维度删除
    imgblob = torch.autograd.Variable(imgblob).cuda()  # 对输入进行数据格式转换,np.array->tensor.cuda.float

    torch.no_grad()  # 不进行梯度计算
    model.eval()  # 使用验证模式
    heatmap = model(imgblob)
    print(heatmap)
    return model


if __name__ == "__main__":

    #显存监控
    GPU_USE=0
    nvmlInit() #初始化
    handle = nvmlDeviceGetHandleByIndex(GPU_USE) #获得指定GPU的handle
    info_begin = nvmlDeviceGetMemoryInfo(handle) #获得显存信息
    #路径设置
    img_str ="/data_1/Working/project/torch-onnx-tvm/data/000d81205274490bb1729e4e26226f01-3-1@00.jpg"
    model_path = "/data1/Working/project/plate_project/plate_clearity/pytorch/exam-0/resnet29_cls-10-regular.pth"
    libtorch_save_path = model_path.split('.')[0]+".pt"
    #进行转换
    gen_torch_to_libtorch(model_path,libtorch_save_path)
    # #输出显存使用mb
    info_end = nvmlDeviceGetMemoryInfo(handle)
    print("-"*15+"TORCH GPU MEMORY INFO"+"-"*15)
    print("       Memory Total: "+str(info_end.total//(1024**2)))
    print("       Memory Free: "+str(info_end.free//(1024**2)))
    print("       Memory Used: "+str(info_end.used//(1024**2)-info_begin.used//(1024**2)))
    print("-" * 40)

    
	
	

	
	
	
	
		

	
	
	
    
    
    
    
    
    





  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值