import time
import torch
from torch import nn, optim
import sys
sys.path.append("…")
import d2lzh_pytorch as d2l
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
# (ASCII quotes: the typographic quotes in the original were a SyntaxError.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class LeNet(nn.Module):
    """LeNet-style convolutional network with sigmoid activations.

    Two conv -> sigmoid -> max-pool stages extract features. Three fully
    connected layers then map the flattened features to 10 output scores.

    The first Linear layer's ``in_features`` is 16 * 4 * 4, which is what the
    conv stack produces for a 1x28x28 input:
    28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4 (with 16 channels).
    Inputs of any other spatial size fail in ``self.fc`` with a shape mismatch.
    """

    def __init__(self):
        # Must be ``__init__`` (not ``init``) so nn.Module registers the
        # submodules below when ``LeNet()`` is constructed.
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5),  # in_channels, out_channels, kernel_size
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2),  # kernel_size, stride
            nn.Conv2d(6, 16, 5),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2),
        )
        self.fc = nn.Sequential(
            # 16 channels x 4 x 4 spatial positions left after the conv stack.
            nn.Linear(16 * 4 * 4, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10),
        )

    def forward(self, img):
        """Return unnormalized class scores of shape (batch, 10) for ``img``.

        ``img`` is expected to be (batch, 1, 28, 28); see the class docstring.
        No softmax is applied here.
        """
        feature = self.conv(img)
        # Flatten everything except the batch dimension before the FC layers.
        output = self.fc(feature.view(img.shape[0], -1))
        return output
def evaluate_accuracy(data_iter, net, device=None):<