CIFAR-10程序

首先介绍CIFAR-10数据集

第一步,写数据库处理

import torch
from mne import label
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms

#新建一个main函数,加载CIFAR-10数据集
def main():
    batchza=32
    cifar_train=datasets.CIFAR10('cifar',True,transform=transforms.Compose([
        transforms.Resize((32,32)), # 把照片改成我需要的大小
        transforms.ToTensor()  #  转换为tensor
    ]),download=True)  # transform,代表要做的一些变化
    cifar_train=DataLoader(cifar_train,batch_size=batchza,shuffle=True)


    cifar_test=datasets.CIFAR10('cifar',False,transform=transforms.Compose([
        transforms.Resize((32,32)), # 把照片改成我需要的大小
        transforms.ToTensor()  #  转换为tensor
    ]),download=True)  # transform,代表要做的一些变化
    cifar_test=DataLoader(cifar_test,batch_size=batchza,shuffle=True)


# iter()可以用来得到dataload的迭代器,然后用迭代器的next方法得到一个batch
    x,label=iter(cifar_train).next()
    print('x:', x.shape,'label:',label.shape)




if __name__ == '__main__':
    main()

然后新建一个module,写LeNet-5

import torch
from torch import nn
from torch.nn import functional as F



class LeNet5(nn.Module):
    '''
    for CIFAR-10
    '''
    def __init__(self):
        #调用类初始化方法,初始化父类
        super(LeNet5,self).__init__()
        #然后查询需要用的网络结构,进行写
# 把网络写在Sequential里面,可以非常方便组织结构
        self.conv_unit=nn.Sequential(
            # 3,代表彩色,卷积核一般1-7
            # x:[b,3,32,32]=>[b,6, , ]
            nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0),
            nn.MaxPool2d(kernel_size=2,stride=2,padding=0),
            #
            nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0),
            nn.MaxPool2d(kernel_size=2,strid
  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值