import torch
import torch.nn as nn
import torchvision.datasets as normal_datasets
import torchvision.transforms as transforms
# Directory where torchvision stores / looks for the MNIST files.
DATA_ROOT = './mnist/'
# Mini-batch size shared by the training and test loaders.
BATCH_SIZE = 100

# ToTensor() turns each image into a float tensor.
# download=True is set on BOTH splits so neither relies on the other having
# fetched the files first; torchvision skips the download when the files
# already exist, so this costs nothing on later runs.
train_dataset = normal_datasets.MNIST(
    root=DATA_ROOT,
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = normal_datasets.MNIST(
    root=DATA_ROOT,
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffle the training data each epoch; read the test data in a fixed order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
class CNN(nn.Module):
    """Three-stage convolutional classifier for 28x28 images.

    Each stage is Conv3x3 (stride 1, padding 1, so the spatial size is kept)
    -> BatchNorm -> ReLU -> MaxPool.  The first two pools halve the spatial
    size and the last one (kernel 2, stride 1) shrinks it by one pixel:
    28 -> 14 -> 7 -> 6.  The linear head therefore expects 6 * 6 * 32
    features, which is why inputs must be 28x28.

    Args:
        in_channels: number of input image channels (default 1, as for MNIST).
        num_classes: number of output logits (default 10, one per digit).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # stride=1 pool: 7x7 -> 6x6 rather than halving.
            nn.MaxPool2d(kernel_size=2, stride=1),
        )
        self.fc = nn.Linear(6 * 6 * 32, num_classes)

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, num_classes).

        x is expected to be (batch, in_channels, 28, 28).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Flatten everything but the batch dimension; unlike .view this also
        # works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        return self.fc(x)
# Use the GPU when present, but fall back to the CPU instead of crashing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cnn = CNN().to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)

for epoch in range(5):
    train_loss = 0.0
    train_correct = 0
    test_loss = 0.0
    test_correct = 0

    # Training pass: update the weights after every mini-batch.
    cnn.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = cnn(images)
        loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss returns the batch MEAN; weight it by the batch size
        # so that dividing by the dataset size below gives the true
        # per-sample mean loss (also correct for a smaller final batch).
        train_loss += loss.item() * labels.size(0)
        train_correct += (outputs.argmax(dim=1) == labels).sum().item()

    # Evaluation pass: BatchNorm uses running stats (eval mode) and no
    # autograd graph is built, since no gradients are needed here.
    cnn.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = cnn(images)
            loss = loss_func(outputs, labels)
            test_loss += loss.item() * labels.size(0)
            test_correct += (outputs.argmax(dim=1) == labels).sum().item()

    print("Epoch:{}, Train_loss:{:.5f}, Train Accuracy:{:.2f}%, "
          "Test_loss:{:.5f}, Test Accuracy:{:.2f}%".format(
              epoch + 1,
              train_loss / len(train_dataset),
              100 * float(train_correct) / len(train_dataset),
              test_loss / len(test_dataset),
              100 * float(test_correct) / len(test_dataset)))
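# PyTorch: MNIST 十分类 (PyTorch: MNIST ten-class classification)
# 最新推荐文章于 2023-02-08 15:28:13 发布 (source page footer: latest recommended article published 2023-02-08 15:28:13)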