// LibTorch实现LeNet (LeNet implemented with LibTorch)
#include<opencv2/opencv.hpp>
#include <torch/torch.h>
#include <torch/script.h>
using namespace std;
/// LeNet-style convolutional classifier built from LibTorch modules.
///
/// Two convolution stages feed three fully connected layers
/// (num_linear -> 128 -> 32 -> num_classes). The submodule holders below start
/// out empty (nullptr) and are created and registered by the constructor.
class LeNet : public torch::nn::Module
{
public:
    /// @param num_classes  number of logits produced by the final layer
    /// @param num_linear   input width of the first fully connected layer; must
    ///                     match the flattened size of the convolutional output
    ///                     for the input resolution being used
    LeNet(int num_classes, int num_linear);

    /// Runs the network on @p x and returns the final-layer outputs (no softmax).
    torch::Tensor forward(torch::Tensor x);

private:
    // Convolutional feature extractor.
    torch::nn::Conv2d conv1{nullptr};
    torch::nn::Conv2d conv2{nullptr};

    // Fully connected classifier head.
    torch::nn::Linear fc1{nullptr};
    torch::nn::Linear fc2{nullptr};
    torch::nn::Linear fc3{nullptr};
};
/// Creates every layer and registers it with the base torch::nn::Module under
/// a fixed name ("conv1", "conv2", "fc1", "fc2", "fc3").
///
/// The first convolution takes 3 input channels. num_linear must equal
/// 16 * H' * W', where H' x W' is the spatial size left after both
/// conv/pool stages in forward() (44944 for a 224x224 input).
LeNet::LeNet(int num_classes, int num_linear)
{
    namespace nn = torch::nn;

    // Feature extractor: 5x5 kernels, channels 3 -> 6 -> 16.
    conv1 = register_module("conv1",
        nn::Conv2d(nn::Conv2dOptions(/*in_channels=*/3, /*out_channels=*/6, /*kernel_size=*/5)));
    conv2 = register_module("conv2",
        nn::Conv2d(nn::Conv2dOptions(/*in_channels=*/6, /*out_channels=*/16, /*kernel_size=*/5)));

    // Classifier head: num_linear -> 128 -> 32 -> num_classes.
    fc1 = register_module("fc1", nn::Linear(nn::LinearOptions(/*in=*/num_linear, /*out=*/128)));
    fc2 = register_module("fc2", nn::Linear(nn::LinearOptions(/*in=*/128, /*out=*/32)));
    fc3 = register_module("fc3", nn::Linear(nn::LinearOptions(/*in=*/32, /*out=*/num_classes)));
}
/// Forward pass: two (conv -> ReLU -> 2x2 max-pool) stages, flatten, then three
/// fully connected layers with ReLU after the first two.
///
/// @param x  presumably an image batch of shape (N, 3, H, W) whose H, W make the
///           flattened features match fc1's input width (e.g. 224x224 for
///           num_linear == 44944) — TODO confirm with callers. A 3-D unbatched
///           (3, H, W) input, if the LibTorch version's Conv2d accepts it, is
///           treated as a single sample, as before.
/// @return   outputs of fc3, one row per sample (no softmax applied).
torch::Tensor LeNet::forward(torch::Tensor x)
{
    auto out = torch::relu(conv1(x));
    out = torch::max_pool2d(out, 2);
    out = torch::relu(conv2(out));
    out = torch::max_pool2d(out, 2);

    // Flatten all dimensions except the batch dimension so each sample keeps its
    // own row. A hard-coded view({1, -1}) would merge an N-sample batch into one
    // row of N * num_linear features and make fc1 fail for every N > 1.
    out = (out.dim() == 4) ? out.flatten(1) : out.reshape({1, -1});

    out = torch::relu(fc1(out));
    out = torch::relu(fc2(out));
    return fc3(out);
}
int main()
{
auto device = torch::Device(torch::kCUDA, 0);
auto input = torch::ones({1,3,224,224});
cout << input.sizes() <<endl ;
auto net = LeNet(5,44944);
try
{
auto output = net.forward(input);
cout << output << endl;
}
catch (const std::exception& e)
{
cout << e.what() << endl;
}
return 0;
}