LibTorch实现LeNet

L i b T o r c h 实现 L e N e t LibTorch实现LeNet LibTorch实现LeNet

#include<opencv2/opencv.hpp>
#include <torch/torch.h>
#include <torch/script.h> 
using namespace std;

/// <summary>
/// LeNet实现类
/// </summary>
class LeNet :public torch::nn::Module {
public:
	// 构造器
	LeNet(int num_classes,int num_linear);
	// 前向传播
	torch::Tensor forward(torch::Tensor x);
private:
	// 具体实现放到构造器实现中
	torch::nn::Conv2d conv1{nullptr};
	torch::nn::Conv2d conv2{nullptr};
	torch::nn::Linear fc1{ nullptr };
	torch::nn::Linear fc2{nullptr};
	torch::nn::Linear fc3{nullptr};
};

LeNet::LeNet(int num_classes, int num_linear)
{
	conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 6, 5)));
	conv2 = register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, 5)));
	fc1 = register_module("fc1", torch::nn::Linear(torch::nn::LinearOptions(num_linear, 128)));
	fc2 = register_module("fc2", torch::nn::Linear(torch::nn::LinearOptions(128, 32)));
	fc3 = register_module("fc3", torch::nn::Linear(torch::nn::LinearOptions(32, num_classes)));
}


torch::Tensor LeNet::forward(torch::Tensor x)
{
	auto out = torch::relu(conv1->forward(x));
	out = torch::max_pool2d(out, 2);
	out = torch::relu(conv2(out));
	out = torch::max_pool2d(out, 2);
	out = out.view({ 1, -1 });
	out = torch::relu(fc1(out));
	out = torch::relu(fc2(out));
	out = fc3(out);
	return out;
}


int main()
{

	//step0:定义使用cuda
	auto device = torch::Device(torch::kCUDA, 0);

	// step1:生成测试数据
	auto input = torch::ones({1,3,224,224});
	cout << input.sizes() <<endl ;

	// step2:生成网络层实例
	
	auto net = LeNet(5,44944);
	
	// step3:推理输出
	try
	{
		auto output = net.forward(input);
		// step4:打印输出和大小
		cout << output << endl;
	}
	catch (const std::exception& e)
	{
		// step5:打印报错
		cout << e.what() << endl;
	}
	
	return 0;
}

在这里插入图片描述

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值