1、先在 PyTorch 中把用 torch.save 保存的神经网络参数文件,转换为 torch.jit.trace 导出的 TorchScript 格式
def convertTracedScriptModule(opt):
    """Convert a torch.save() checkpoint of the generator into a TorchScript file.

    Builds ``network.Generator`` when ``opt.input_size >= 256`` and
    ``network.Generator_Small`` otherwise, loads the ``'G_model'`` weights from
    ``opt.weights``, traces the network with ``torch.jit.trace`` on a random
    ``1 x in_ch x input_size x input_size`` input and saves the traced module
    under ``./traced_script_module/`` (the file name encodes dataset, epoch,
    channel counts and input size).

    Args:
        opt: options object; reads ``weights``, ``input_size``, ``in_ch``,
            ``out_ch``, ``ngf`` and ``dataset``.
    """
    import os  # local import: only needed to create the output directory

    if opt.weights == "not_use":
        print("please set the path of pretrained net")
        return

    # load pretrained net
    if opt.input_size >= 256:
        G = network.Generator(opt.in_ch, opt.out_ch, opt.ngf)
    else:
        G = network.Generator_Small(opt.in_ch, opt.out_ch, opt.ngf)

    # map_location="cpu" lets a checkpoint written from a GPU run be converted on
    # a CPU-only machine; the model is traced on CPU tensors below anyway.
    ckpt = torch.load(opt.weights, map_location="cpu")
    # strict=False tolerates key mismatches, but tracing a partially initialised
    # network silently is hard to debug, so report anything that was skipped.
    incompatible = G.load_state_dict(ckpt['G_model'], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print("warning: missing keys:", incompatible.missing_keys)
        print("warning: unexpected keys:", incompatible.unexpected_keys)
    epoch = ckpt['epoch']
    G.eval()
    print(G)

    batch_size = 1
    example = torch.rand(batch_size, opt.in_ch, opt.input_size, opt.input_size)
    traced_script_module = torch.jit.trace(G, example)

    # Sanity check: run the traced module once on a constant input.
    with torch.no_grad():
        output = traced_script_module(torch.ones(batch_size, opt.in_ch, opt.input_size, opt.input_size))
    print(output)

    # save traced net
    module_name = './traced_script_module/%s_epoch_%d_in_ch_%d_out_ch%d_input_size_%d.pt' % (
        opt.dataset, epoch, opt.in_ch, opt.out_ch, opt.input_size)
    # ScriptModule.save() does not create missing parent directories.
    os.makedirs(os.path.dirname(module_name), exist_ok=True)
    traced_script_module.save(module_name)
    print(module_name + " is saved")
2、在c++中加载traced_script_module
定义一个简单的类 TracedScriptModel,用来加载预训练的网络并执行预测。头文件 TracedScriptModel.h 及其实现如下:
其中 load 用来加载预训练的模型;runForward 执行预测,batch_size 为 1 时,返回结果是多通道数据,即按 通道×行×列 展平的三维连续数据,长度为 out_channel*out_rows*out_cols;runForwardFacies 也执行预测,但返回多通道合并后的二维离散数据:每个像素的取值是输出值超过阈值(_cutoff)的通道号——若有多个通道超过阈值,取通道号最大的那个;若没有任何通道超过阈值,则取 0。
#pragma once
#include <string>
#include <vector>
#include <memory>
#include <torch/script.h>

/// Minimal wrapper that loads a TorchScript module exported with
/// torch.jit.trace and runs single-image inference on it.
///
/// Usage: construct with the input geometry, call load() once, then call
/// runForward() / runForwardFacies() as often as needed.
class TracedScriptModel
{
public:
    /// Uses the default geometry: 1 input channel, 128 x 128 input.
    TracedScriptModel();

    /// @param inChannel number of channels of the network input.
    /// @param inputSize number of rows (== number of columns) of the input.
    ///
    /// inChannel intentionally has no default argument. If both parameters had
    /// defaults, this constructor and TracedScriptModel() would both match a
    /// zero-argument call, and `TracedScriptModel m;` would not compile.
    TracedScriptModel(int inChannel, int inputSize = 128);
    ~TracedScriptModel();

    /// Loads the TorchScript file at @p modelPath; returns 0 on success.
    /// Errors raised by torch::jit::load propagate as exceptions.
    int load(std::string& modelPath);

    /// Runs the network on one channels-last (rows x cols x channels) input
    /// and returns the flattened output tensor (out_channel * out_rows *
    /// out_cols values for batch size 1).
    std::vector<float> runForward(std::vector<float>& inputSeis);

    /// Runs the network and returns one label per output pixel: the index of
    /// the highest-numbered output channel whose value exceeds the cutoff,
    /// or 0 when no channel does.
    std::vector<int> runForwardFacies(std::vector<float>& inputSeis);

private:
    std::shared_ptr<torch::jit::Module> _module = nullptr;  ///< set by load()
    int _inCh = 1;         ///< number of input channels
    int _inputSize = 128;  ///< input rows == input cols
    float _cutoff = 0.3f;  ///< threshold used by runForwardFacies()
};
#include "TracedScriptModel.h"

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
// Default construction keeps the in-class member defaults
// (1 input channel, 128 x 128 input, no module loaded yet).
TracedScriptModel::TracedScriptModel() = default;
TracedScriptModel::TracedScriptModel(int inChannel, int inputSize):
_inCh(inChannel), _inputSize(inputSize)
{
}
// Nothing to release by hand: the shared_ptr member owns the loaded module.
TracedScriptModel::~TracedScriptModel() = default;
/// Deserializes the TorchScript module stored at @p modelPath and prints the
/// fully qualified name of every parameter it contains.
///
/// @param modelPath path to a file written by ScriptModule.save() in Python.
/// @return 0 once the module has been loaded.
/// Errors raised by torch::jit::load (e.g. unreadable or invalid file) are not
/// caught here and propagate to the caller.
int TracedScriptModel::load(std::string& modelPath)
{
    // torch::jit::load reports failure by throwing, and make_shared never
    // yields nullptr, so no null check is needed afterwards.
    _module = std::make_shared<torch::jit::Module>(torch::jit::load(modelPath));
    std::cout << "moudle is loaded ok\n";

    // named_parameters() recurses into submodules by default, so one pass over
    // the root lists each parameter exactly once with its full dotted name.
    // (Calling it on every entry of modules() printed each parameter once per
    // enclosing module.)
    for (const auto& param : _module->named_parameters()) {
        std::cout << param.name << '\n';
    }
    return 0;
}
std::vector<float> TracedScriptModel::runForward( std::vector<float>& inputSeis)
{
// transfer data to tensor
torch::Tensor tensor_input = torch::from_blob(inputSeis.data(), { _inputSize, _inputSize, _inCh }, torch::kFloat);
tensor_input = tensor_input.permute({ 2, 0, 1 });
//tensor_input = tensor_input.toType(torch::kFloat);
//tensor_input = tensor_input.div(255);
tensor_input = tensor_input.unsqueeze(0);
//std::cout << tensor_input << std::endl;
// 网络前向计算
// Execute the model and turn its output into a tensor.
at::Tensor output = _module->forward({ tensor_input }).toTensor();
//std::cout << output << std::endl;
std::vector<float> v(output.data_ptr<float>(), output.data_ptr<float>() + output.numel());
return std::move(v);
}
std::vector<int> TracedScriptModel::runForwardFacies(std::vector<float>& inputSeis)
{
// transfer data to tensor
torch::Tensor tensor_input = torch::from_blob(inputSeis.data(), { _inputSize, _inputSize, _inCh }, torch::kFloat);
tensor_input = tensor_input.permute({ 2, 0, 1 });
//tensor_input = tensor_input.toType(torch::kFloat);
//tensor_input = tensor_input.div(255);
tensor_input = tensor_input.unsqueeze(0);
//std::cout << tensor_input << std::endl;
// 网络前向计算
// Execute the model and turn its output into a tensor.
at::Tensor output = _module->forward({ tensor_input }).toTensor();
//std::cout << output << std::endl;
auto tsize = output.sizes();
int out_ch = tsize[1];
int out_rows = tsize[2];
int out_cols = tsize[3];
at::Tensor facies = at::zeros({ out_rows, out_cols }, torch::kInt32);
at::Tensor ones = at::ones({ out_rows, out_cols }, torch::kInt32);
for (int ch = 0; ch < out_ch; ch++) {
facies = at::where(output[0][ch] > _cutoff, ones * ch, facies);
}
std::vector<int> v(facies.data_ptr<int>(), facies.data_ptr<int>() + facies.numel());
return v;
}
3、模块调用和结果显示
设置预训练模型的路径,加载预训练模型,输入所需数据,得到预测结果。通过matplotlibcpp进行显示
/// Demo driver: loads the traced generator, predicts the facies map for one
/// seismic slice read from a .npy file and displays it with matplotlibcpp.
/// argc/argv are not used; the model and data paths are hard-coded.
void testPreNet(int argc, char* argv[])
{
    const int inCh = 1;
    const int inputSize = 128;
    // The traced model has 3 output channels ("out_ch3" in the file name).
    TracedScriptModel tsm{ inCh, inputSize };
    std::string modelPath = R"(H:\deeplearning\pix2pix_geomodel\pix2pix_lyf\traced_script_module\yuejin_epoch_442_in_ch_1_out_ch3_input_size_128.pt)";
    if (tsm.load(modelPath) != 0) {
        return;
    }

    std::string seisNpyFile = R"(H:\yuejin\yuejin_sample\yuejin_facies_freq6_segI\seismic_848.npy)";
    auto inputSeis = EclipseModel::IModel2D<float>::loadNpy(seisNpyFile);
    // Amplitude scaling; presumably the same factor used for the training data — TODO confirm.
    inputSeis /= 800;

    // Alternative: display each raw output channel instead of the merged labels.
    /*std::vector<float> output = tsm.runForward(inputSeis.grid());
    cout << output.size() << endl;
    plt::imshow(output.data(), inputSize, inputSize, 1);
    plt::show();
    plt::imshow(output.data() + inputSize * inputSize, inputSize, inputSize, 1);
    plt::show();
    plt::imshow(output.data() + inputSize * inputSize * 2, inputSize, inputSize, 1);
    plt::show();*/
    /*plt::imshow(output.data(), inputSize, inputSize, 3);
    plt::show();*/

    const std::vector<int> output = tsm.runForwardFacies(inputSeis.grid());
    // Labels are output-channel indices, small enough for an 8-bit image.
    std::vector<unsigned char> facies;
    facies.reserve(output.size());
    for (const int label : output) {
        facies.push_back(static_cast<unsigned char>(label));
    }
    plt::imshow(facies.data(), inputSize, inputSize, 1);
    plt::show();
}
效果如下