// LibTorch实现MLP(多层感知机) — implementing an MLP (multilayer perceptron) with LibTorch
#include<opencv2/opencv.hpp>
#include <torch/torch.h>
#include <torch/script.h>
using namespace std;
/// A fully connected layer followed by a ReLU activation: relu(linear(x)).
///
/// The wrapped torch::nn::Linear is registered as a submodule under the
/// name "linear1". TORCH_MODULE below generates the `LinearRelu` holder type
/// that the rest of the file uses.
class LinearReluImpl : public torch::nn::Module {
public:
    /// @param input  number of input features of the linear layer
    /// @param output number of output features of the linear layer
    LinearReluImpl(int input, int output)
        : fc_(register_module("linear1", torch::nn::Linear(input, output))) {}

    /// Applies the linear projection, then ReLU, and returns the result.
    torch::Tensor forward(torch::Tensor x) {
        torch::Tensor projected = fc_->forward(x);
        return torch::relu(projected);
    }

private:
    // Set in the constructor's initializer list; nullptr only as a placeholder.
    torch::nn::Linear fc_{ nullptr };
};
TORCH_MODULE(LinearRelu);
/// Multilayer perceptron made of three LinearRelu hidden layers
/// (widths 32 -> 64 -> 128) followed by a plain linear output layer.
/// The output layer has no activation, so forward() returns raw scores.
class MLP : public torch::nn::Module {
public:
    /// @param in_features  input feature count of the first layer (ln1)
    /// @param out_features output feature count of the final layer (out_ln)
    MLP(int in_features, int out_features);

    /// Runs x through ln1 -> ln2 -> ln3 -> out_ln and returns the result.
    torch::Tensor forward(torch::Tensor x);

private:
    // Hidden-layer widths: identical for every instance and never modified,
    // so they are a shared compile-time constant (C++17 inline variable).
    static constexpr int mid_features[3] = { 32, 64, 128 };
    LinearRelu ln1{ nullptr };
    LinearRelu ln2{ nullptr };
    LinearRelu ln3{ nullptr };
    torch::nn::Linear out_ln{ nullptr };
};

MLP::MLP(int in_features, int out_features) {
    // Construct and register each layer in one step. Layers are created and
    // registered in the same order and under the same names as before
    // ("ln1", "ln2", "ln3", "out_ln").
    ln1 = register_module("ln1", LinearRelu(in_features, mid_features[0]));
    ln2 = register_module("ln2", LinearRelu(mid_features[0], mid_features[1]));
    ln3 = register_module("ln3", LinearRelu(mid_features[1], mid_features[2]));
    out_ln = register_module("out_ln", torch::nn::Linear(mid_features[2], out_features));
}

torch::Tensor MLP::forward(torch::Tensor x) {
    x = ln1->forward(x);
    x = ln2->forward(x);
    x = ln3->forward(x);
    return out_ln->forward(x);
}
int main()
{
auto device = torch::Device(torch::kCUDA, 0);
auto input = torch::ones({100});
cout << input.sizes() << " "<< input<<endl ;
auto net = MLP(100,10);
try
{
auto output = net.forward(input);
cout << output << endl;
}
catch (const std::exception& e)
{
cout << e.what() << endl;
}
return 0;
}