如何使用 libtorch 实现 LeNet 网络?

如何使用 libtorch 实现 LeNet 网络?

LeNet 网络论文地址:
http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf

LeNet

C1 卷积层

{1,1,28,28} 是什么?

1 输入的批次
1 图像的通道大小
28 图像的高
28 图像的宽

输入:{1,1,28,28}

通过填充一个边界 2 ,使得输入变成 {1,1,32,32}

滑动窗口大小:{5,5}

输出:{1,6,32,32}

S2 降采样

输入:{1,6,32,32}

滑动窗口大小:{2,2,}
滑动步长:{2,2}

输出:{1,6,14,14}

C3 卷积层

输入:{1,16,14,14}

滑动窗口大小:{5,5}

输出:{1,16,10,10}

S4 降采样

输入:{1,16,10,10}

滑动窗口大小:{2,2,}
滑动步长:{2,2}

输出:{1,16,5,5}

C5 卷积层

输入:{1,16,5,5}

滑动窗口大小:{5,5}

输出:{1,120,1,1}

F6 全连接层

这里要把网络形状从 {1,120,1,1} 改变改变成 {1,120}

第一个全连接
输入:{1,120}
输出:{1,84}

第二个全连接
输入:{1,84}
输出:{84,10}

0~9 总共是 10 个类别嘛,这里就输出 10个就行了。

全连接就是线性层,网络形状不一样不能全连接的,所以这里要把形状改变成一样的。
基本按照那图写一遍就明白了。

关于输入和输出的网络推断公式可以去参考 pytorch 里面的函数说明,上面都有写推断公式滴。

// Define a new Module.
struct Net : torch::nn::Module {
    Net() {
        conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 6, /*kernel_size*/{ 5,5 }).padding(/*28->32*/{2,2})));
        conv2 = register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, /*kernel_size*/{5,5})));
        conv3 = register_module("conv3", torch::nn::Conv2d(torch::nn::Conv2dOptions(16, 120, /*kernel_size*/{5,5})));
        fc1 = register_module("fc1", torch::nn::Linear(torch::nn::LinearOptions(120, 84)));
        fc2 = register_module("fc2", torch::nn::Linear(torch::nn::LinearOptions(84, 10)));
    }

    // Implement the Net's algorithm.
    torch::Tensor forward(torch::Tensor x) {
        x = conv1->forward(x);//6@28x28
        x = torch::max_pool2d(x, { 2,2 }, { 2,2 });//6@14x14
        x = conv2->forward(x);//16@10x10
        x = torch::max_pool2d(x, { 2,2 }, { 2,2 });//16@10x10
        
        x = conv3->forward(x);//120@1x1
        x = x.view({ x.size(0),-1 });//-1 表示自动推理计算出该值
        x = fc1->forward(x);//120->84
        x = fc2->forward(x);//84->10
        x = torch::log_softmax(x,/*dim=*/1);
        return x;
    }

    // Use one of many "standard library" modules.
    torch::nn::Conv2d conv1 { nullptr };
    torch::nn::Conv2d conv2 { nullptr };
    torch::nn::Conv2d conv3 { nullptr };
    torch::nn::Linear fc1{ nullptr };
    torch::nn::Linear fc2{ nullptr };
};

转载于:https://www.cnblogs.com/cheungxiongwei/p/10710968.html

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值