首先介绍CIFAR-10数据集
第一步,写数据库处理
import torch
from mne import label
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
#新建一个main函数,加载CIFAR-10数据集
def main():
batchza=32
cifar_train=datasets.CIFAR10('cifar',True,transform=transforms.Compose([
transforms.Resize((32,32)), # 把照片改成我需要的大小
transforms.ToTensor() # 转换为tensor
]),download=True) # transform,代表要做的一些变化
cifar_train=DataLoader(cifar_train,batch_size=batchza,shuffle=True)
cifar_test=datasets.CIFAR10('cifar',False,transform=transforms.Compose([
transforms.Resize((32,32)), # 把照片改成我需要的大小
transforms.ToTensor() # 转换为tensor
]),download=True) # transform,代表要做的一些变化
cifar_test=DataLoader(cifar_test,batch_size=batchza,shuffle=True)
# iter()可以用来得到dataload的迭代器,然后用迭代器的next方法得到一个batch
x,label=iter(cifar_train).next()
print('x:', x.shape,'label:',label.shape)
if __name__ == '__main__':
main()
然后新建一个module,写LeNet-5
import torch
from torch import nn
from torch.nn import functional as F
class LeNet5(nn.Module):
'''
for CIFAR-10
'''
def __init__(self):
#调用类初始化方法,初始化父类
super(LeNet5,self).__init__()
#然后查询需要用的网络结构,进行写
# 把网络写在Sequential里面,可以非常方便组织结构
self.conv_unit=nn.Sequential(
# 3,代表彩色,卷积核一般1-7
# x:[b,3,32,32]=>[b,6, , ]
nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0),
nn.MaxPool2d(kernel_size=2,stride=2,padding=0),
#
nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0),
nn.MaxPool2d(kernel_size=2,strid