在c++中利用libtorch部署python中训练的pytorch网络

1、先在pytroch中把torch.save的神经网络参数文件转化为git.trace格式

def convertTracedScriptModule(opt):

    # load pretraind net
    if opt.input_size >=256:
        G = network.Generator(opt.in_ch, opt.out_ch, opt.ngf)
        # D = network.Discriminator(opt.in_ch, opt.out_ch, opt.ndf)
    else:
        G = network.Generator_Small(opt.in_ch, opt.out_ch, opt.ngf)
        # D = network.Discriminator_Small(opt.in_ch , opt.out_ch, opt.ndf)

    if opt.weights == "not_use":
        print("please set the path of pretrained net")
        return
    ckpt = torch.load(opt.weights)
    G.load_state_dict(ckpt['G_model'], strict=False)
    epoch = ckpt['epoch']
    G.eval()
    print(G)

    batch_size = 1
    example = torch.rand(batch_size,opt.in_ch,opt.input_size, opt.input_size)
    traced_script_module = torch.jit.trace(G, example)
    output = traced_script_module(torch.ones(batch_size,opt.in_ch,opt.input_size, opt.input_size))
    print(output)

    # save loaded net
    module_name = './traced_script_module/%s_epoch_%d_in_ch_%d_out_ch%d_input_size_%d.pt'%(opt.dataset, epoch, opt.in_ch, opt.out_ch, opt.input_size)
    traced_script_module.save(module_name)
    print(module_name + " is saved")
    

2、在c++中加载traced_script_module

定义一个简单的类,TraceScriptModel,用来加载预训练的网络,并执行预测,头文件如下

其中load用来加载预训练的模型,runForward方法执行预测,batch_size为 1是,返回结果是多通道的数据,即三维连续数据,长度是out_channel*out_rows*out*cols, runForwardFacies也执行预测,返回结果是多通道合并后的二维离散数据,离散值是满足门槛值的通道号。

#pragma once

#include <string>
#include <vector>
#include <memory>
#include <torch/script.h>


class TracedScriptModel
{
public:
	TracedScriptModel();
	TracedScriptModel(int inChannel=1, int inputSize=128);
	~TracedScriptModel();

	int load(std::string& modelPath);

	std::vector<float> runForward(std::vector<float>&inputSeis);

	std::vector<int> runForwardFacies(std::vector<float>& inputSeis);

private:
	std::shared_ptr<torch::jit::Module> _module = nullptr;
	int _inCh=1, _inputSize=128;
	float _cutoff = 0.3;
};

#include "TracedScriptModel.h"

TracedScriptModel::TracedScriptModel()
{
}

TracedScriptModel::TracedScriptModel(int inChannel, int inputSize):
    _inCh(inChannel), _inputSize(inputSize)
{

}

TracedScriptModel::~TracedScriptModel()
{
}

int TracedScriptModel::load(std::string& modelPath)
{
    // Deserialize the ScriptModule from a file using torch::jit::load().
    _module = std::make_shared<torch::jit::Module>(torch::jit::load(modelPath));

    assert(_module != nullptr);
    std::cout << "moudle is loaded ok\n";

    for (const auto& subModule : _module->modules()) {
        for (const auto& parms: subModule.named_parameters()) {
            std::cout << parms.name << std::endl;
            //std::cout << parms.value << std::endl;
        }
    }

	return 0;
}

std::vector<float> TracedScriptModel::runForward( std::vector<float>& inputSeis)
{
    // transfer data to tensor
    torch::Tensor tensor_input = torch::from_blob(inputSeis.data(), { _inputSize, _inputSize, _inCh }, torch::kFloat);
    tensor_input = tensor_input.permute({ 2, 0, 1 });
    //tensor_input = tensor_input.toType(torch::kFloat);
    //tensor_input = tensor_input.div(255);

    tensor_input = tensor_input.unsqueeze(0);
    //std::cout << tensor_input << std::endl;

    // 网络前向计算
    // Execute the model and turn its output into a tensor.
    at::Tensor output = _module->forward({ tensor_input }).toTensor();
    //std::cout << output << std::endl;
    std::vector<float> v(output.data_ptr<float>(), output.data_ptr<float>() + output.numel());
    return std::move(v);
      
}

std::vector<int> TracedScriptModel::runForwardFacies(std::vector<float>& inputSeis)
{
    // transfer data to tensor
    torch::Tensor tensor_input = torch::from_blob(inputSeis.data(), { _inputSize, _inputSize, _inCh }, torch::kFloat);
    tensor_input = tensor_input.permute({ 2, 0, 1 });
    //tensor_input = tensor_input.toType(torch::kFloat);
    //tensor_input = tensor_input.div(255);

    tensor_input = tensor_input.unsqueeze(0);
    //std::cout << tensor_input << std::endl;

    // 网络前向计算
    // Execute the model and turn its output into a tensor.
    at::Tensor output = _module->forward({ tensor_input }).toTensor();
    //std::cout << output << std::endl;
    auto tsize = output.sizes();
    int out_ch = tsize[1];
    int out_rows = tsize[2];
    int out_cols = tsize[3];
    at::Tensor facies = at::zeros({ out_rows, out_cols }, torch::kInt32);
    at::Tensor ones = at::ones({ out_rows, out_cols }, torch::kInt32);
    for (int ch = 0; ch < out_ch; ch++) {
        facies = at::where(output[0][ch] > _cutoff, ones * ch, facies);
    }

    std::vector<int> v(facies.data_ptr<int>(), facies.data_ptr<int>() + facies.numel());

    return v;
}

3、模块调用和结果显示

设置预训练模型的路径,加载预训练模型,输入所需数据,得到预测结果。通过matplotlibcpp进行显示

void testPreNet(int argc, char* argv[])
{
    int inCh = 1;
    int inputSize = 128;
    int outCh = 3;
    TracedScriptModel tsm{ inCh, inputSize };
    std::string modelPath = R"(H:\deeplearning\pix2pix_geomodel\pix2pix_lyf\traced_script_module\yuejin_epoch_442_in_ch_1_out_ch3_input_size_128.pt)";
    tsm.load(modelPath);
    //std::vector<float> inputDat(128 * 128, 0.5);
    //tsm.runForward(inputDat);

    std::string seisNpyFile = R"(H:\yuejin\yuejin_sample\yuejin_facies_freq6_segI\seismic_848.npy)";
    auto inputSeis = EclipseModel::IModel2D<float>::loadNpy(seisNpyFile);
    inputSeis /= 800;
    /*std::vector<float> output = tsm.runForward(inputDat.grid());
    cout << output.size() << endl;
    plt::imshow(output.data(), inputSize, inputSize, 1);
    plt::show();
    plt::imshow(output.data() + inputSize * inputSize, inputSize, inputSize, 1);
    plt::show();
    plt::imshow(output.data() + inputSize * inputSize * 2, inputSize, inputSize, 1);
    plt::show();*/
    /*plt::imshow(output.data(), inputSize, inputSize, 3);
    plt::show();*/
    std::vector<int>output = tsm.runForwardFacies(inputSeis.grid());
    std::vector<unsigned char> facies(output.size());
    for (int i = 0; i < output.size(); i++) {
        facies[i] = output[i];
    }
    plt::imshow(facies.data(), inputSize, inputSize, 1);
    plt::show();
}

效果如下

 

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

oceanstonetree

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值