怎样用libtorch跑多输入节点的网络呢?

libtorch的forward函数输入参数格式为std::vector<IValue>,当网络输入有多个Tensor时,把这些Tensor依次pushback进这个vector即可。

举例说明:

class CAB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CAB, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x1, x2):
        #x1, x2 = x  # high, low
        x = torch.cat([x1, x2], dim=1)
        x = self.global_pooling(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x2 = x * x2
        res = x2 + x1
        return res

以上CAB为pytorch中定义的网络结构,有x1和x2两个输入

trace代码如下:

CAB_model = CAB(128, 64)
X1 = torch.rand([1, 64, 48, 48])
X2 = torch.rand([1, 64, 48, 48])
Y = CAB_model(X1, X2)
CAB_traced_script_module = torch.jit.trace(CAB_model, (X1, X2))
CAB_traced_script_module.save("./traced/traced_CAB.pt")

libtorch调用代码如下:

torch::jit::script::Module CAB_model;
try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    CAB_model = torch::jit::load("./traced/traced_CAB.pt");
}
catch (const c10::Error& e) {
    std::cerr << "error loading the twoStream model\n";
    return -1;
}

at::Tensor X1_tensor = torch::zeros({ 1, 64, 48, 48 });
at::Tensor X2_tensor = torch::zeros({ 1, 64, 48, 48 });
std::vector<torch::jit::IValue> inputs;
inputs.push_back(X1_tensor);
inputs.push_back(X2_tensor);

auto CAB_output = CAB_model.forward(inputs).toTensor();
c10::IntList size;
size = CAB_output.sizes();
std::cout << size.at(0) << "," << size.at(1) << "," << size.at(2) << "," << size.at(3) << std::endl;

参考链接:

https://github.com/pytorch/pytorch/issues/15523

  • 4
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值