main函数
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn ,optim
from lenet5 import Lenet5
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
def main():
batchsz= 32
"""
'cifar' 就是在当前目录下面新建一个叫做cifar的文件夹,train=ture ??(是否训练嘛?)
transform 就是对数据集做一些变换
resize 大小维度的转换
totensor 转换成为 tensor
"""
cifar_train =datasets.CIFAR10('cifar',True,transform=transforms.Compose([
transforms.Resize((32,32)),
transforms.ToTensor()
]),download=True)
cifar_train= DataLoader(cifar_train,batch_size=batchsz,shuffle=True)
cifar_test =datasets.CIFAR10('cifar',False,transform=transforms.Compose([
transforms.Resize((32,32)),
transforms.ToTensor()
]),download=True)
cifar_test= DataLoader(cifar_test,batch_size=batchsz,shuffle=True)
x,label=iter(cifar_train).next()
print('x:',x.shape,'label:',label.shape)
device=torch.device('cuda')
model=Lenet5().to(device)
criteon=nn.CrossEntropyLoss().to(device)
optimizer =optim.Adam(model.parameters(),lr=1e-3)
print(model)
for epoch in range(1000):
model.train()
for batchsz,(x,label) in enumerate(cifar_train):
x,label=x.to(device),label.to(device)
logits=model(x)
loss=criteon(logits,label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(epoch,loss.item())
model.eval()
with torch.no_grad():
total_num=0
total_correct=0
for x,label in cifar_test:
x,label=x.to(device),label.to(device)
logits=model(x)
pred=logits.argmax(dim=1)
total_correct+=torch.eq(pred,label).float().sum().item()
total_num+=x.size(0)
acc=total_correct/total_num
print("当前epoch :{} ,当前准确率:{}".format(epoch,acc))
if __name__ == '__main__':
main()
LENET文件
import torch
from torch import nn
from torch.nn import functional as F
class Lenet5(nn.Module):
"""
for cifar10 dataset.
"""
def __init__(self):
super(Lenet5,self).__init__()
self.conv_unit=nn.Sequential(
nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
)
self.fc_unit =nn.Sequential(
nn.Linear(16*5*5,120),
nn.ReLU(),
nn.Linear(120,84),
nn.ReLU(),
nn.Linear(84,10)
)
"""
下面这段代码用于测试conv_unit 的输出的维度形状
[b,3,32,32]
"""
tmp=torch.randn(2,3,32,32)
out=self.conv_unit(tmp)
print("out.shape is:",out.shape)
self.criteon =nn.CrossEntropyLoss
def forward(self,x):
"""
:param x:[]
:return:
"""
batchsz=x.size(0)
x=self.conv_unit(x)
x=x.view(batchsz,-1)
logits=self.fc_unit(x)
return logits
def main():
net=Lenet5()
tmp = torch.randn(2, 3, 32, 32)
out = net(tmp)
print('lenet out:',out.shape)
if __name__=='__main__':
main()
----------------------------------------
----------------------------------------
----------------------------------------
----------------------------------------
运行结果:
0 1.2656571865081787
当前epoch :0 ,当前准确率:0.4524
1 1.4220787286758423
当前epoch :1 ,当前准确率:0.4927
2 0.9537531137466431
当前epoch :2 ,当前准确率:0.5178
3 1.113728404045105
当前epoch :3 ,当前准确率:0.5216
4 1.3806077241897583
当前epoch :4 ,当前准确率:0.5441
5 0.9763231873512268
当前epoch :5 ,当前准确率:0.5368
6 0.813972532749176
当前epoch :6 ,当前准确率:0.5436
7 0.949809193611145
当前epoch :7 ,当前准确率:0.5496
8 1.1697014570236206
当前epoch :8 ,当前准确率:0.538
9 1.0288864374160767
当前epoch :9 ,当前准确率:0.5554
10 1.2099034786224365
当前epoch :10 ,当前准确率:0.5553
11 2.150331497192383
当前epoch :11 ,当前准确率:0.5511
12 1.0876156091690063
当前epoch :12 ,当前准确率:0.5483
13 0.7157190442085266
当前epoch :13 ,当前准确率:0.5526
14 0.6226208806037903
当前epoch :14 ,当前准确率:0.5489
15 0.8635637164115906
当前epoch :15 ,当前准确率:0.5476
16 0.7294909954071045
当前epoch :16 ,当前准确率:0.552
17 1.023239254951477
当前epoch :17 ,当前准确率:0.5402
18 0.8226995468139648
当前epoch :18 ,当前准确率:0.5475
19 0.6349995732307434
当前epoch :19 ,当前准确率:0.5425
20 0.49438467621803284
当前epoch :20 ,当前准确率:0.5468
21 1.0320965051651
当前epoch :21 ,当前准确率:0.5395