LeNet 网络的 PyTorch 代码实现
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
# Convert PIL images to tensors in [0, 1], then normalize each RGB channel with
# (x - 0.5) / 0.5 so values end up in [-1, 1].
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Training set (downloaded into ./dataset on first run).
train_dataset = torchvision.datasets.CIFAR10(
    "dataset", train=True, transform=transform, download=True
)
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True, num_workers=0)

# Test set. batch_size=10000 covers the whole CIFAR-10 test split, so a single
# batch is pulled once and reused for every accuracy check during training.
# Sample order cannot change an accuracy score, so shuffling is disabled.
testset = torchvision.datasets.CIFAR10(
    "dataset", train=False, transform=transform, download=True
)
testloader = DataLoader(testset, batch_size=10000, shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_label = next(test_data_iter)
class LeNet(nn.Module):
    """LeNet-style convolutional network for 32x32 images.

    Layout: conv(5x5) -> ReLU -> maxpool(2) -> conv(5x5) -> ReLU -> maxpool(2)
    -> three fully connected layers. For a 32x32 input the spatial size goes
    32 -> 28 -> 14 -> 10 -> 5, so ``linear1`` receives 32 * 5 * 5 features.

    Args:
        num_classes: Number of output logits. Defaults to 10.
        in_channels: Number of input image channels. Defaults to 3 (RGB).
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 3):
        super().__init__()
        # Layer attribute names are part of the saved state_dict keys
        # (e.g. "conv1.weight"); keep them stable so checkpoints still load.
        self.conv1 = nn.Conv2d(in_channels, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.linear1 = nn.Linear(32 * 5 * 5, 120)
        self.linear2 = nn.Linear(120, 84)
        self.linear3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (N, num_classes).

        ``x`` must be (N, in_channels, 32, 32). Any other spatial size makes
        ``linear1`` raise instead of being silently reshaped.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. The old
        # view(-1, 800) could fold a wrongly sized input into a bogus batch
        # size instead of failing.
        x = torch.flatten(x, 1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return self.linear3(x)
# Prefer the GPU but fall back to CPU, instead of hard-coding .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = LeNet().to(device)
loss_function = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# The pre-loaded test batch only needs to be moved to the device once, not on
# every evaluation.
test_image = test_image.to(device)
test_label = test_label.to(device)

NUM_EPOCHS = 10
EVAL_INTERVAL = 500  # report / evaluate every this many optimizer steps


def _evaluate_accuracy(model, images, labels):
    """Return the fraction of ``images`` that ``model`` labels correctly.

    Runs in eval mode without gradients, then puts the model back in train mode.
    """
    model.eval()
    with torch.no_grad():
        predict_y = model(images).argmax(dim=1)
    model.train()
    return (predict_y == labels).sum().item() / labels.size(0)


writer = SummaryWriter("keshihua")
step = 0
# Sum and count of batch losses since the last report. They are reset at every
# report so the printed loss is a windowed average, not an ever-growing sum.
running_loss = 0.0
batches_since_report = 0
# Training
try:
    for epoch in range(NUM_EPOCHS):
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Read the scalar once; reuse it for both the running average and
            # TensorBoard.
            loss_value = loss.item()
            running_loss += loss_value
            batches_since_report += 1
            writer.add_scalar("train_loss", loss_value, step)

            if step % EVAL_INTERVAL == 0:
                accuracy = _evaluate_accuracy(net, test_image, test_label)
                avg_loss = running_loss / batches_since_report
                # Print the same step index that is logged to TensorBoard.
                print(f"step={step},train_loss={avg_loss:.4f},test_accuracy={accuracy:.4f}")
                writer.add_scalar("test_accuracy", accuracy, step)
                running_loss = 0.0
                batches_since_report = 0
            step += 1
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()
print("训练完成")
save_path = 'LeNet.pth'
torch.save(net.state_dict(), save_path)
训练误差以及测试集的准确率（数据集比较老，B 站博主跑出来的准确率也是 68% 左右）
TensorBoard 可视化训练误差和测试准确率