[LibTorch] 简单的例程

#include <torch/torch.h>
#include <iostream>

struct NetImpl : torch::nn::Module {
    NetImpl(int k) : 
        conv1(torch::nn::Conv2dOptions(k, 256, 1)
                .bias(false)),
        batch_norm1(256),
        conv2(torch::nn::Conv2dOptions(256, 128, 1)
                .stride(2)
                .padding(1)
                .bias(false)),
        batch_norm2(128),
        conv3(torch::nn::Conv2dOptions(128, 64, 1)
                .stride(2)
                .padding(1)
                .bias(false)),
        batch_norm3(64),
        conv4(torch::nn::Conv2dOptions(64, 1, 1)
                .stride(2)
                .padding(1)
                .bias(false))
    {
        // register_module() is needed if we want to use the parameters() method later on
        register_module("conv1", conv1);
        register_module("conv2", conv2);
        register_module("conv3", conv3);
        register_module("conv4", conv4);
        register_module("batch_norm1", batch_norm1);
        register_module("batch_norm2", batch_norm2);
        register_module("batch_norm3", batch_norm3);
    }

    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(batch_norm1(conv1(x)));
        x = torch::relu(batch_norm2(conv2(x)));
        x = torch::relu(batch_norm3(conv3(x)));
        x = torch::tanh(conv4(x));
        return x;
    }

    torch::nn::Conv2d conv1, conv2, conv3, conv4;
    torch::nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(Net);

int main() {
    torch::Device device = torch::kCPU;
    if (torch::cuda::is_available()) {
        std::cout << "CUDA is available! Training on GPU." << std::endl;
        device = torch::kCUDA;
    }
    Net net(3);
    net->to(device);
    net->eval();
    auto x = torch::ones({1, 3, 224, 224}, torch::kFloat).to(device);
    torch::NoGradGuard no_grad;
    auto out = net->forward(x);
    std::cout << out << std::endl;
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值