LeNet5

import  torch
from torch import nn
from torch.nn import functional as F



class Lenet5(nn.Module):
    '''
    for cifar10 dataset
    '''
    def __init__(self):
        super(Lenet5, self).__init__()

        self.conv_unit = nn.Sequential(
            # x: [b,3,32,32]
            nn.Conv2d(3,6,kernel_size=5, stride=1,padding=0),
            # =>[b,6,
            nn.AvgPool2d(kernel_size=2, stride=2,padding=0),
            #
            nn.Conv2d(6,16,kernel_size=5, stride=1,padding=0),
            #
            nn.AvgPool2d(kernel_size=2, stride=2,padding=0),
            # 打平

        )

        # flatten

        # fc unit
        self.fc_unit = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84,10)
        )



        # [b, 3,32,32]
        tmp = torch.randn(2,  3, 32,32)
        out = self.conv_unit(tmp)
        # [b,16,5,5]
        print("conv out:", out.shape)
        # use Cross Entropy Loss
        # nn.MSELoss
        self.criteon = nn.CrossEntropyLoss()
    def forward(self, x):
        '''

        :param x: [b,3,32,32]
        :return:
        '''
        batchsz = x.size(0)   # x.shape[0]
        # [b,3,32,32]=>[b,16,5,5]
        x = self.conv_unit(x)
        # [b,16,5,5]=>[b,16*5*5]
        # 打平
        x = x.view(batchsz, -1)
        # [b,16*5*5]=>[b,10]
        logits = self.fc_unit(x)

        # pred = F.softmax(logits, dim=1)
        # loss = self.criteon(logits,y) nn 和F 的区别, 一个要初始化,另一个直接运行
        return logits







def main():
    net = Lenet5()
    # [b, 3,32,32]
    tmp = torch.randn(2, 3, 32, 32)
    out = net(tmp)
    # [b,16,5,5]
    print("conv out:", out.shape)
    # use Cross Entropy Loss
    # nn.MSELoss



if __name__ == '__main__':
    main()

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值