libtorch---day01[第一个Neural Network]

代码

代码参考 PyTorch 官方教程中的 Neural Networks 示例。

#include <torch/torch.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/functional.h>
#include <torch/optim/sgd.h>

#include <iostream>
#include <map>
#include <numeric>
#include <string>
#include <vector>

#include "matplotlibcpp.h"

namespace plt = matplotlibcpp;

// A small LeNet-style network: two convolution layers followed by three
// fully connected layers. The final layer emits 50 values so the output
// matches the random length-50 regression target built in main().
class Net : public torch::nn::Module {
public:
	// Register each submodule directly in the member-initializer list.
	// register_module() returns the holder it was given, so construction
	// and registration happen in one step; the original two-step
	// construct-then-reassign pattern is unnecessary and error-prone.
	Net()
		: conv1(register_module("conv1",
			torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 6, 5)))),
		  conv2(register_module("conv2",
			torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, 5)))),
		  fc1(register_module("fc1", torch::nn::Linear(16 * 5 * 5, 120))),
		  fc2(register_module("fc2", torch::nn::Linear(120, 84))),
		  fc3(register_module("fc3", torch::nn::Linear(84, 50)))
	{
	}

	// Forward pass for a (N, 1, 32, 32) input:
	// conv -> relu -> 2x2 max-pool, twice, then flatten and the three
	// fully connected layers (no activation after the last one).
	torch::Tensor forward(torch::Tensor x) {
		x = torch::relu(conv1->forward(x));
		x = torch::max_pool2d(x, 2);
		x = torch::relu(conv2->forward(x));
		x = torch::max_pool2d(x, 2);
		x = torch::flatten(x, 1);
		x = torch::relu(fc1->forward(x));
		x = torch::relu(fc2->forward(x));
		x = fc3->forward(x);
		return x;
	}

	// Rule of Zero: no user-declared destructor needed -- the module
	// holders clean up after themselves.

	// Layers are public members so they can be inspected from outside.
	torch::nn::Conv2d conv1{ nullptr }, conv2{ nullptr };
	torch::nn::Linear fc1{ nullptr }, fc2{ nullptr }, fc3{ nullptr };
};
// Train the network for 100 iterations on one random (input, target) pair,
// then plot prediction vs. target with matplotlib-cpp and save a PNG.
int main() {
	Net net;

	// Random input: batch=1, channel=1, height=width=32 (LeNet-sized).
	auto input = torch::randn({ 1, 1, 32, 32 });
	// Random regression target: 50 values, reshaped to (1, 50) to match
	// the network output.
	auto target = torch::randn(50);
	target = target.view({ 1, -1 });

	auto criterion = torch::nn::MSELoss();
	auto optimizer = torch::optim::SGD(net.parameters(), torch::optim::SGDOptions(0.01));

	for (int i = 0; i < 100; i++)
	{
		auto y = net.forward(input);
		auto loss = criterion(y, target);
		optimizer.zero_grad();
		loss.backward();
		optimizer.step();
		std::cout << "Epoch [" << i << "/100], Loss: " << loss.item<float>() << std::endl;
	}

	// Final inference pass; no gradients are needed for plotting.
	torch::NoGradGuard no_grad;
	auto y = net.forward(input);
	y = y.view({ 1, -1 });

	// Plot with matplotlib-cpp, see https://github.com/lava/matplotlib-cpp
	plt::figure_size(1200, 780);

	// Convert the torch tensors to std::vector<float> for plotting.
	// contiguous() guarantees the raw data pointer walks the elements
	// in order.
	std::vector<float> x_plot(50);
	std::iota(x_plot.begin(), x_plot.end(), 1.f);
	auto predict = y.contiguous();
	auto ground_truth = target.contiguous();
	const float* predict_ptr = predict.data_ptr<float>();
	const float* ground_truth_ptr = ground_truth.data_ptr<float>();
	std::vector<float> y_plot_predict(predict_ptr, predict_ptr + predict.numel());
	std::vector<float> y_plot_target(ground_truth_ptr, ground_truth_ptr + ground_truth.numel());

	// Plain mutable option map. (The original declared it const and then
	// mutated it through const_cast, which is undefined behavior on a
	// truly const object -- removed.)
	std::map<std::string, std::string> plot_opts{
		{"label", "prediction"}, {"linestyle", "-"}, {"color", "red"}};
	plt::plot(x_plot, y_plot_predict, plot_opts);
	plot_opts["label"] = "ground_truth";
	plot_opts["color"] = "blue";
	plot_opts["linestyle"] = "--";
	plt::plot(x_plot, y_plot_target, plot_opts);
	plt::title("simple neural network, 100 iteration");
	plt::legend();
	plt::save("SimpleNeuralNetwork.png");
	return 0;
}

关键点

定义网络层

需要显式调用

// Quoted from the libtorch source: the ModuleHolder overload of
// Module::register_module. It unwraps the holder via .ptr() and forwards
// to the shared_ptr overload, so the submodule's parameters become part
// of the parent module (and thus of net.parameters()).
template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
    std::string name,
    ModuleHolder<ModuleType> module_holder) {
  return register_module(std::move(name), module_holder.ptr());
}

进行网络层的定义。

结果

(此处为结果图,即程序保存的 SimpleNeuralNetwork.png:红色实线为 prediction,蓝色虚线为 ground truth。)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值