最近在学习 PyTorch,手工复现了 LeNet 网络,源码附在下方,欢迎大家留言交流。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class LeNet(nn.Module):
    """LeNet-style convolutional network for 1x28x28 images and 10 classes.

    Layout: two 5x5 convolutions (1 -> 6 -> 16 channels), each followed by
    ReLU and 2x2 max pooling, then three fully connected layers
    (256 -> 120 -> 84 -> 10). ``forward`` returns log-probabilities.
    """

    def __init__(self):
        super().__init__()
        # 5x5 kernels, stride 1, no padding.
        self.conv1 = nn.Conv2d(1, 6, 5, 1)
        self.conv2 = nn.Conv2d(6, 16, 5, 1)
        # For a 28x28 input, 16 channels of 4x4 remain after the second pool.
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute per-class log-probabilities.

        Args:
            x: image tensor whose trailing dimensions are (1, 28, 28),
                normally shaped (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, 10) holding ``log_softmax`` outputs.

        Raises:
            RuntimeError: if the feature map after the second pooling stage
                is not (16, 4, 4), i.e. the input was not 28x28.
        """
        # x: (N, 1, 28, 28)
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)  # -> (N, 6, 12, 12)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)  # -> (N, 16, 4, 4)
        # Guard before flattening: view(-1, 256) would otherwise silently mix
        # samples together whenever the total element count happens to be a
        # multiple of 256 (e.g. a batch of 16 images of size 32x32).
        if tuple(x.shape[-3:]) != (16, 4, 4):
            raise RuntimeError(
                "LeNet expects 1x28x28 inputs; got feature map of shape "
                f"{tuple(x.shape)} before the fully connected layers"
            )
        x = x.view(-1, 16 * 4 * 4)
        # Non-linearities between the linear layers; without them the three
        # layers collapse into a single affine transform.
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
def train(model, device, train_loader, optimizer, epoch):
model.train()
for idx, (data, target) in enumerate(train_loader):
data, target =