libtorch c++ 定义全链接网络

目录

1. main函数

2. 搭建MLP网络

2.1 mlp.h

2.2 mlp.cpp

1. main函数

(1)实例化网络,该网络模型必须继承torch::nn::Module类;

(2)获取训练数据,其中输入数据维度和目标数据维度都是(b,n);

(3)实例化优化器,这里用的是Adam,学习率是0.0005;

(4)forward,在求mse_loss,backward,step。

#include<torch/script.h>
#include<torch/torch.h>
#include"mlp.h"

int main()
{
	MLP mlp(20, 2);  // model
	at::Tensor input_x = torch::rand({ 4, 20 });  // data
	at::Tensor input_y = torch::ones({ 4, 2 });
	torch::optim::Adam optimizer(mlp.parameters(), 0.001);
	for (int epoch = 0; epoch < 100; epoch++)  // train
	{
		optimizer.zero_grad();
		at::Tensor output = mlp.forward(input_x);
		at::Tensor loss = torch::mse_loss(output, input_y);
		loss.backward();
		optimizer.step();
		std::cout << loss.item().toFloat() << std::endl;
	}
}

out shape: [2, 1]; target shape:[2, 1]

2. 搭建MLP网络

MLP多层感知机,也就是全连接网络.

2.1 mlp.h

#ifndef MLP_H
#define MLP_H
#endif // MLP_H

#include <torch/torch.h>
#include <torch/script.h>

// 小模块:fc+bn+relu
class LinearBnReluImpl : public torch::nn::Module {
public:
    LinearBnReluImpl(int intput_features, int output_features);
    torch::Tensor forward(torch::Tensor x);
private:
    //layers
    torch::nn::Linear ln{ nullptr };  // 定义私有成员,先构造函数初始化,再在forward函数使用。
    torch::nn::BatchNorm1d bn{ nullptr };
};
TORCH_MODULE(LinearBnRelu);


class MLP : public torch::nn::Module {
public:
    MLP(int in_features, int out_features);  // 构造函数:输入特征维度,和输出特征维度
    torch::Tensor forward(torch::Tensor x);  // 推理函数
private:
    int mid_features[3] = { 32,64,128 };       // 中间层特征维度
    LinearBnRelu ln1{ nullptr };               // 3个(linear + bn + relu)
    LinearBnRelu ln2{ nullptr };
    LinearBnRelu ln3{ nullptr };
    torch::nn::Linear out_ln{ nullptr };       // 普通的linear层
};

2.2 mlp.cpp

3个linear+bn+relu,最后接一个linear.

#include "mlp.h"

// 实现LinearBnRelu
// 注册线性层、bn层
LinearBnReluImpl::LinearBnReluImpl(int in_features, int out_features) {
    ln = register_module("ln", torch::nn::Linear(torch::nn::LinearOptions(in_features, out_features)));
    // 注意bn操作时,训练时batch_size必须设置成大于1,否则没意义且会报错,测试时会屏蔽此操作
    bn = register_module("bn", torch::nn::BatchNorm1d(out_features));
}
// linear->relu->bn
torch::Tensor LinearBnReluImpl::forward(torch::Tensor x) {
    x = torch::relu(ln->forward(x));
    x = bn(x);
    return x;
}

MLP::MLP(int in_features, int out_features) {
    ln1 = LinearBnRelu(in_features, mid_features[0]);  // 初始化
    ln2 = LinearBnRelu(mid_features[0], mid_features[1]);
    ln3 = LinearBnRelu(mid_features[1], mid_features[2]);
    out_ln = torch::nn::Linear(mid_features[2], out_features);

    ln1 = register_module("ln1", ln1);  // 构造函数注册轮子
    ln2 = register_module("ln2", ln2);
    ln3 = register_module("ln3", ln3);
    out_ln = register_module("out_ln", out_ln);
}

torch::Tensor MLP::forward(torch::Tensor x) {
    x = ln1->forward(x);   // 逐个forward,因为每个都是module,有各自的forward函数。
    x = ln2->forward(x);   // 
    x = ln3->forward(x);
    x = out_ln->forward(x);
    return x;
}

loss逐渐收敛

 

参考:LibtorchTutorials/lesson3-BasicModels at main · AllentDan/LibtorchTutorials · GitHub

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Mr.Q

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值