代码
以下代码参考 PyTorch(libtorch C++ 前端)官方示例实现。
#include <torch/torch.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/functional.h>
#include <torch/optim/sgd.h>
#include <iostream>
#include <map>
#include <numeric>
#include <string>
#include <vector>
#include "matplotlibcpp.h"
namespace plt = matplotlibcpp;
// A small LeNet-style CNN: two conv+maxpool stages followed by three
// fully-connected layers. Expects input of shape (N, 1, 32, 32) and
// produces an output of shape (N, 50).
class Net : public torch::nn::Module {
public:
    Net() {
        // register_module both registers the submodule (so its parameters
        // appear in parameters() for the optimizer) and returns the holder,
        // so we construct and register in a single step instead of building
        // the layer in the init-list and then reassigning it.
        conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 6, 5)));
        conv2 = register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, 5)));
        fc1 = register_module("fc1", torch::nn::Linear(16 * 5 * 5, 120));
        fc2 = register_module("fc2", torch::nn::Linear(120, 84));
        fc3 = register_module("fc3", torch::nn::Linear(84, 50));
    }

    // Forward pass: conv -> relu -> pool, twice, then flatten and three
    // linear layers (no activation on the final layer — raw regression output).
    torch::Tensor forward(torch::Tensor x) {
        x = torch::max_pool2d(torch::relu(conv1->forward(x)), 2);  // (N,6,28,28) -> (N,6,14,14)
        x = torch::max_pool2d(torch::relu(conv2->forward(x)), 2);  // (N,16,10,10) -> (N,16,5,5)
        x = torch::flatten(x, 1);                                  // (N, 16*5*5)
        x = torch::relu(fc1->forward(x));
        x = torch::relu(fc2->forward(x));
        return fc3->forward(x);
    }

    // Rule of Zero: no user-declared destructor needed — the module holders
    // clean up after themselves.

    // Layers are public members so they can be inspected from outside the class.
    torch::nn::Conv2d conv1{nullptr}, conv2{nullptr};
    torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};
int main() {
Net net;
// 随机生成输入,要求batch为1,channel为1,height和width分别为32
auto input = torch::randn({ 1, 1, 32, 32 });
// 随机生成输出,为一个长度为50的张量
auto target = torch::randn(50);
target = target.view({ 1, -1 });
auto criterion = torch::nn::MSELoss();
auto optimizer = torch::optim::SGD(net.parameters(), torch::optim::SGDOptions(0.01));
for (int i = 0; i < 100; i++)
{
auto y = net.forward(input);
auto loss = criterion(y, target);
optimizer.zero_grad();
loss.backward();
optimizer.step();
std::cout << "Epoch [" << i << "/100], Loss: " << loss.item<float>() << std::endl;
}
auto y = net.forward(input);
y = y.view({ 1, -1 });
// 绘制结果,这里用matplotlib-cpp,参考https://github.com/lava/matplotlib-cpp
plt::figure_size(1200, 780);
// 数据格式转化,主要是torch转为vector
std::vector<float> x_plot(50);
std::iota(x_plot.begin(), x_plot.end(), 1.);
auto predict_continous = y.contiguous();
auto predict_data_ptr = predict_continous.data_ptr<float>();
auto groud_true_continous = target.contiguous();
auto groud_true_data_ptr = groud_true_continous.data_ptr<float>();
std::vector<float> y_plot_predict(predict_data_ptr, predict_data_ptr + predict_continous.numel());
std::vector<float> y_plot_target(groud_true_data_ptr, groud_true_data_ptr + groud_true_continous.numel());
// 这里只是笔者最近在复习c++八股文,所以用一下const_cast,实际应用中,只需要一个构造map即可
const std::map<std::string, std::string> plot_map({ {"label", ""}, {"linestyle", "-"}, {"color", "red"}});
auto& plot_map_ = const_cast<std::map<std::string, std::string>&>(plot_map);
plot_map_["label"] = "prediction";
plt::plot(x_plot, y_plot_predict, plot_map);
plot_map_["label"] = "ground_true";
plot_map_["color"] = "blue";
plot_map_["linestyle"] = "--";
plt::plot(x_plot, y_plot_target, plot_map);
plt::title("simple neural network, 100 iteration");
plt::legend();
plt::save("SimpleNeuralNetwork.png");
return 0;
}
关键点:定义网络层
定义网络层时,需要显式调用
template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
std::string name,
ModuleHolder<ModuleType> module_holder) {
return register_module(std::move(name), module_holder.ptr());
}
进行网络层的定义。