libtorch---day03 [Custom Derivatives]

Reference: the corresponding PyTorch example.

Background

The goal is to fit a sine function over one period using a Legendre polynomial.
Ground truth: $y = \sin(x),\ x \in [-\pi, \pi]$

torch::Tensor x = torch::linspace(-M_PI, M_PI, 2000, torch::kFloat);
torch::Tensor y = torch::sin(x);

The prediction is a degree-$n = 3$ Legendre polynomial: $\hat{y} = a + b \times P_3(c + dx)$, where $P_3(x) = \frac{1}{2}(5x^3 - 3x)$.
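
The backward pass will need the derivative of $P_3$; differentiating the definition above gives:

$P_3'(x) = \frac{1}{2}(15x^2 - 3) = \frac{3}{2}(5x^2 - 1)$

This is exactly the factor 1.5 * (5 * torch::pow(input, 2) - 1) that appears in the backward function below.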

Building the Autograd Class

torch provides a mechanism that lets developers define the forward pass and the backward gradient computation themselves:

1. Write a class that inherits from torch::autograd::Function;
2. Define static forward and backward functions in the class. They must be static: when torch::autograd::Function::apply is called in the forward pass, and when the autograd engine later runs the backward pass, these two static functions are invoked automatically;

struct LegenderPolynominal3 : public torch::autograd::Function<LegenderPolynominal3>
{
	// Forward: compute P3(x) = 0.5 * (5x^3 - 3x) and stash the input for the backward pass.
	static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input)
	{
		ctx->save_for_backward({ input });
		return 0.5 * (5 * torch::pow(input, 3) - 3 * input);
	}

	// Backward: chain rule, grad_input = dL/dy * P3'(x), with P3'(x) = 1.5 * (5x^2 - 1).
	static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext* ctx, std::vector<torch::Tensor> grad_output)
	{
		auto saved = ctx->get_saved_variables();
		torch::Tensor input = saved[0];

		torch::Tensor grad_input = grad_output[0] * 1.5 * (5 * torch::pow(input, 2) - 1);

		return { grad_input };
	}
};

Key Points

  • You must explicitly call ctx->save_for_backward({ input }); to save the tensors needed later, and auto saved = ctx->get_saved_variables(); to retrieve them;
  • The forward function computes the prediction, matching the usual notion of a forward pass;
  • The input to backward is grad_output, the gradient of the loss with respect to the output, $\frac{\partial L}{\partial y}$; backward must produce the gradient of the loss with respect to the input, $\frac{\partial L}{\partial x}$, via the chain rule $\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \times \frac{\partial y}{\partial x}$;
  • Note that the parameter lists of forward and backward are fixed (a small verification sketch follows this list).
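
To sanity-check the custom backward, here is a minimal sketch (assuming the LegenderPolynominal3 struct above is in scope) comparing the autograd gradient against the analytic derivative at a single point:

#include <torch/torch.h>
#include <iostream>

int main()
{
	// Scalar input that requires gradients.
	torch::Tensor t = torch::tensor({0.5f}).set_requires_grad(true);
	torch::Tensor out = LegenderPolynominal3::apply(t);
	out.sum().backward();
	// Analytic derivative at x = 0.5: 1.5 * (5 * 0.25 - 1) = 0.375.
	std::cout << "autograd: " << t.grad().item<float>()
		<< ", analytic: " << 1.5f * (5 * 0.25f - 1.0f) << std::endl;
	return 0;
}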

Full Code

#include <torch/torch.h>
#include <iostream>
#include "matplotlibcpp.h"

struct LegenderPolynominal3 : public torch::autograd::Function<LegenderPolynominal3>
{
	static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input)
	{
		ctx->save_for_backward({ input });
		return 0.5 * (5 * torch::pow(input, 3) - 3 * input);
	}

	static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext* ctx, std::vector<torch::Tensor> grad_output)
	{
		auto saved = ctx->get_saved_variables();
		torch::Tensor input = saved[0];

		torch::Tensor grad_input = grad_output[0] * 1.5 * (5 * torch::pow(input, 2) - 1);

		return { grad_input };
	}
};
void plot_tensor_xy_compare(const torch::Tensor x, const torch::Tensor y, const torch::Tensor predict)
{
	auto data_ptr = x.data_ptr<float>();
	std::vector<float> x_vector(data_ptr, data_ptr + x.numel());
	data_ptr = y.data_ptr<float>();
	std::vector<float> y_vector(data_ptr, data_ptr + y.numel());
	data_ptr = predict.data_ptr<float>();
	std::vector<float> predict_vector(data_ptr, data_ptr + predict.numel());

	std::map<std::string, std::string> key_words({ {"label", "ground_truth"}, {"color", "blue"}, {"linestyle", "-"}});
	matplotlibcpp::plot(x_vector, y_vector, key_words);
	key_words["color"] = "red";
	key_words["linestyle"] = "--";
	key_words["label"] = "prediction";
	matplotlibcpp::plot(x_vector, predict_vector, key_words);
	matplotlibcpp::grid(true);
	matplotlibcpp::legend();
	matplotlibcpp::show();
}
int main()
{
	torch::Tensor x = torch::linspace(-M_PI, M_PI, 1000, torch::kFloat);
	torch::Tensor y = torch::sin(x);

	torch::Tensor a = torch::full({}, 0., torch::kFloat).set_requires_grad(true);
	torch::Tensor b = torch::full({}, -1., torch::kFloat).set_requires_grad(true);
	torch::Tensor c = torch::full({}, 0., torch::kFloat).set_requires_grad(true);
	torch::Tensor d = torch::full({}, 0.3, torch::kFloat).set_requires_grad(true);
	double learning_rate = 5e-6;
	torch::nn::MSELoss criterion;
	torch::optim::SGD optimizer({a, b, c, d}, torch::optim::SGDOptions(learning_rate));
	for (int i = 0; i < 2000; i++)
	{
		auto P3 = LegenderPolynominal3::apply(c + d * x);
		torch::Tensor predict = a + b * P3;

		torch::Tensor loss = (predict - y).pow(2).sum();
		// auto loss = criterion(predict, y);
		loss.backward();

		optimizer.step();
		optimizer.zero_grad();

		std::cout << "iteration: " << i + 1 << "/2000" << ", loss: " << loss.item<double>() << std::endl;
	}
	auto P3 = LegenderPolynominal3::apply(c + d * x);
	torch::Tensor predict = a + b * P3;

	plot_tensor_xy_compare(x, y, predict);
	return 0;
}
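
Note that the hand-written loss uses a sum reduction. The commented-out alternative, torch::nn::MSELoss, defaults to mean reduction, so swapping it in directly would shrink the gradients by a factor of the sample count; configuring the reduction explicitly (a sketch using MSELossOptions) keeps the two equivalent:

// Equivalent to (predict - y).pow(2).sum():
torch::nn::MSELoss criterion(torch::nn::MSELossOptions().reduction(torch::kSum));
auto loss = criterion(predict, y);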

Results

(Figure: ground-truth sine curve in solid blue vs. fitted prediction in dashed red)

The Corresponding nn Module

#include <torch/torch.h>
#include "matplotlibcpp.h"

using namespace torch;
void plot_tensor_xy_compare(const torch::Tensor x, const torch::Tensor y, const torch::Tensor predict)
{
	auto data_ptr = x.data_ptr<float>();
	std::vector<float> x_vector(data_ptr, data_ptr + x.numel());
	data_ptr = y.data_ptr<float>();
	std::vector<float> y_vector(data_ptr, data_ptr + y.numel());
	data_ptr = predict.data_ptr<float>();
	std::vector<float> predict_vector(data_ptr, data_ptr + predict.numel());

	std::map<std::string, std::string> key_words({ {"label", "ground_truth"}, {"color", "blue"}, {"linestyle", "-"} });
	matplotlibcpp::plot(x_vector, y_vector, key_words);
	key_words["color"] = "red";
	key_words["linestyle"] = "--";
	key_words["label"] = "prediction";
	matplotlibcpp::plot(x_vector, predict_vector, key_words);
	matplotlibcpp::grid(true);
	matplotlibcpp::legend();
	matplotlibcpp::show();
}
class auto_grad : public nn::Module
{
public:
	Tensor a, b, c, d;
	auto_grad() : a(torch::full({}, 0., kFloat).set_requires_grad(true)),
		b(torch::full({}, -1., kFloat).set_requires_grad(true)),
		c(torch::full({}, 0., kFloat).set_requires_grad(true)),
		d(torch::full({}, 0.3, kFloat).set_requires_grad(true))
	{
		// Registering makes the tensors visible to net.parameters() (and thus to the optimizer).
		register_parameter("a", a);
		register_parameter("b", b);
		register_parameter("c", c);
		register_parameter("d", d);
	}
	// Built-in ops only, so autograd derives the backward pass automatically.
	Tensor forward(Tensor input)
	{
		auto P3 = c + d * input;
		return a + b * (0.5 * (5 * torch::pow(P3, 3) - 3 * P3));
	}
};
int main()
{
	auto_grad net;
	nn::MSELoss criterion;
	optim::SGDOptions opt(1e-5);
	opt.momentum(0.9);
	optim::SGD optim(net.parameters(), opt);

	torch::Tensor x = torch::linspace(-M_PI, M_PI, 1000, torch::kFloat);
	torch::Tensor y = torch::sin(x);
	int iteration = 1000;
	for (int i = 0; i < iteration; i++)
	{
		auto predict = net.forward(x);
		auto loss = (predict - y).pow(2).sum();

		loss.backward();
		optim.step();
		optim.zero_grad();
		printf("[training iteration: %d/%d, loss: %lf]\n", i + 1, iteration, loss.item<double>());
	}
	auto predict = net.forward(x);
	plot_tensor_xy_compare(x, y, predict);
	return 0;
}

Key Points

1. Use register_parameter to explicitly register the parameters (see the sketch below for combining this module with the custom autograd Function).
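
The two approaches also compose: a hypothetical variant of the module's forward could route through the custom LegenderPolynominal3 Function from the first half instead of relying on built-in autograd (a sketch, assuming that struct is in scope):

	// Hypothetical forward using the custom Function's hand-written gradient.
	Tensor forward(Tensor input)
	{
		return a + b * LegenderPolynominal3::apply(c + d * input);
	}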
