The C++ Frontend

PyTorch C++ Frontend是一个用于 CPU 和 GPU 张量计算的 C++14 库,具有用于最先进机器学习应用程序的自动微分和高级构建块。

描述

PyTorch C++ Frontend可以被认为是 PyTorch Python Frontend的 C++ 版本,为机器学习和神经网络提供自动微分和各种更高级别的抽象。

具体来说,它由以下组件组成:

组件描述
torch::Tensor可自动微分、高效的 CPU 和 GPU 启用张量
torch::nn用于神经网络建模的可组合模块的集合
torch::optim使用 SGD、Adam 或 RMSprop 等优化算法来训练您的模型
torch::data数据集、数据管道和多线程、异步数据加载器
torch::serialize用于存储和加载模型检查点的序列化 API
torch::python将你的 C++ 模型绑定到 Python 中
torch::jit对 TorchScript JIT 编译器的纯 C++ 访问

端到端示例

这是一个简单的端到端示例,用于在 MNIST 数据集上定义和训练一个简单的神经网络:

#include <torch/torch.h>

// Define a new Module.
struct Net : torch::nn::Module
{
  Net()
  {
    // Construct and register two Linear submodules.
    fc1 = register_module("fc1", torch::nn::Linear(784, 64));
    fc2 = register_module("fc2", torch::nn::Linear(64, 32));
    fc3 = register_module("fc3", torch::nn::Linear(32, 10));
  }

  // Implement the Net's algorithm.
  torch::Tensor forward(torch::Tensor x)
  {
    // Use one of many tensor manipulation functions.
    x = torch::relu(fc1->forward(x.reshape({x.size(0), 784})));
    x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
    x = torch::relu(fc2->forward(x));
    x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
    return x;
  }

  // Use one of many "standard library" modules.
  torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};

int main()
{
  // Create a new Net.
  auto net = std::make_shared<Net>();

  // Create a multi-threaded data loader for the MNIST dataset.
  auto data_loader = torch::data::make_data_loader(
      torch::data::datasets::MNIST("./data").map(
          torch::data::transforms::Stack<>()),
      /*batch_size=*/64);

  // Instantiate an SGD optimization algorithm to update our Net's parameters.
  torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);

  for(size_t epoch = 1; epoch <= 10; ++epoch)
  {
    size_t batch_index = 0;
    // Iterate the data loader to yield batches from the dataset.
    for(auto& batch : *data_loader)
    {
      // Reset gradients.
      optimizer.zero_grad();
      // Execute the model on the input data.
      torch::Tensor prediction = net->forward(batch.data);
      // Compute a loss value to judge the prediction of our model.
      torch::Tensor loss = torch::nll_loss(prediction, batch.target);
      // Compute gradients of the loss w.r.t. the parameters of our model.
      loss.backward();
      // Update the parameters based on the calculated gradients.
      optimizer.step();
      // Output the loss and checkpoint every 100 batches.
      if(++batch_index % 100 == 0)
      {
        std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                  << " | Loss: " << loss.item<float>() << std::endl;
        // Serialize your model periodically as a checkpoint.
        torch::save(net, "net.pt");
      }
    }
  }
}

要查看使用 PyTorch C++ Frontend的更完整示例,请参阅示例存储库

使用建议

PyTorch 的 C++ Frontend的设计理念是 Python Frontend很棒,应该尽可能使用它; 但在某些设置中,性能和可移植性要求使 Python 解释器的使用不可行。例如,对于低延迟、高性能或多线程环境(如视频游戏或生产服务器),Python 是一个糟糕的选择。C++ Frontend的目标是解决这些用例,同时不牺牲 Python Frontend的用户体验。

因此,编写 C++ Frontend时考虑了一些哲学目标:

  • 在设计、命名、约定和功能方面对 Python Frontend进行密切建模。虽然这两个Frontend之间可能偶尔会有差异(例如,我们在 Python Frontend中删除了不推荐使用的功能),但我们保证将 Python 模型移植到 C++ 的努力应该完全在于翻译语言特性,而不是修改功能或行为。
  • 优先考虑灵活性和用户友好性而不是微优化。 在 C++ 中,您通常可以获得最佳代码,但代价是极其不友好的用户体验。灵活性和动态性是 PyTorch 的核心,C++ Frontend试图保留这种体验,在某些情况下会牺牲性能(或“隐藏”性能旋钮)以保持 API 的简单和可解释性。我们希望不以编写 C++ 为生的研究人员能够使用我们的 API。

一句话警告:Python 不一定比 C++ 慢! Python Frontend调用 C++ 来处理几乎任何计算成本高的事情(尤其是任何类型的数字运算),而这些运算将占用程序中花费的大部分时间。如果您更喜欢编写 Python,并且有能力编写 Python,我们建议您使用 PyTorch 的 Python 接口。但是,如果您更喜欢编写 C++,或者需要编写 C++(由于多线程、延迟或部署要求),PyTorch 的 C++ Frontend提供的 API 与 Python 对应的 API 大致一样方便、灵活、友好和直观。这两个Frontend服务于不同的用例,并肩工作,两者都不能无条件地替换另一个。

安装

有关如何安装 C++ Frontend库的说明,包括如何根据 LibTorch 构建最小应用程序的示例,可以通过该链接找到。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值