import torch
from torchvision import transforms
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
# ---- Hyperparameters ----------------------------------------------------
epoch_num = 10   # number of passes over the training set
lr = 0.1         # learning rate handed to the SGD optimizer
momentum = 0.5   # momentum handed to the SGD optimizer
batch_size = 16  # mini-batch size shared by the train and test loaders

# ---- Data pipeline ------------------------------------------------------
# Convert each image to a tensor, then normalise its single channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# The root used to be spelled '.\data'.  "\d" is not a recognised string
# escape (newer Python versions warn about it), and a backslash is only a
# path separator on Windows; a forward slash is accepted on every platform.
data_root = './data'

# download=False: the MNIST files must already be present under data_root.
train_set = mnist.MNIST(data_root, train=True, transform=transform, download=False)
test_set = mnist.MNIST(data_root, train=False, transform=transform, download=False)

train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Evaluation does not depend on sample order, so the test split is not
# shuffled.
test_data = DataLoader(test_set, batch_size=batch_size, shuffle=False)
class Net(nn.Module):
    """Small convolutional classifier: two conv stages, global pooling, MLP head.

    Pipeline: conv(4x4) -> BN -> ReLU -> max-pool(2) -> conv(4x4) -> BN ->
    ReLU -> max-pool(2) -> adaptive average pool to 1x1 -> flatten ->
    linear(36, 100) -> BN -> ReLU -> linear(100, num_classes).

    Args:
        in_channels: Number of channels of the input images. Defaults to 1,
            which is the value the network was originally hard-coded to.
        num_classes: Number of output logits. Defaults to 10, the original
            hard-coded value.

    Note:
        The attribute names (including the ``lawyer1`` / ``lawyer2``
        spelling) are kept unchanged on purpose: they determine the keys of
        ``state_dict()``, so renaming them would break existing checkpoints.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()
        self.lawyer1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, 4),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
        )
        self.pool1 = nn.MaxPool2d(2, 2)
        self.lawyer2 = nn.Sequential(
            nn.Conv2d(16, 36, 4),
            nn.BatchNorm2d(36),
            nn.ReLU(True),
        )
        self.pool2 = nn.MaxPool2d(2, 2)
        # Collapses whatever spatial extent is left to 1x1, so the linear
        # head always receives exactly 36 features per sample.
        self.aap = nn.AdaptiveAvgPool2d(1)
        self.line1 = nn.Sequential(
            nn.Linear(36, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(True),
        )
        self.line2 = nn.Linear(100, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for a batch of images.

        Args:
            x: Batch of images; the first dimension is the batch and the
                second must match ``in_channels``.

        Returns:
            Tensor with one row per sample and ``num_classes`` columns.
        """
        x = self.lawyer1(x)
        x = self.pool1(x)
        x = self.lawyer2(x)
        x = self.pool2(x)
        x = self.aap(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.shape[0], -1)
        x = self.line1(x)
        x = self.line2(x)
        return x
def _evaluate(model, loader, loss_fn, device):
    """Return ``(mean_loss, accuracy)`` of ``model`` over every sample in ``loader``.

    ``loss_fn`` is expected to return the mean loss of a batch (the default
    reduction of ``nn.CrossEntropyLoss``).  Each batch mean is weighted by
    the batch size so the result is a true per-sample average even when the
    last batch is smaller than the others.

    Returns ``(nan, nan)`` when the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for img, label in loader:
            img = img.to(device)
            label = label.to(device)
            out = model(img)
            count = label.shape[0]
            # Undo the per-batch mean to get the summed loss of the batch.
            loss_sum += loss_fn(out, label).item() * count
            correct += (out.argmax(dim=1) == label).sum().item()
            seen += count
    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, correct / seen


module = Net()
device = torch.device('cpu')
module = module.to(device)
optimizer = optim.SGD(module.parameters(), lr=lr, momentum=momentum)
loss_fun = nn.CrossEntropyLoss()

for epoch in range(epoch_num):
    # Multiply the learning rate by 0.6 on every even epoch.
    # NOTE(review): this also fires at epoch 0, so the first optimisation
    # step already runs at 0.6 * lr rather than lr.  Kept as-is because it
    # affects training results -- confirm whether that is intended.
    if epoch % 2 == 0:
        for group in optimizer.param_groups:
            group['lr'] *= 0.6

    module.train()
    for img, label in train_data:
        img = img.to(device)
        label = label.to(device)
        out = module(img)
        loss = loss_fun(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The loss used to be divided by the batch size a second time although
    # CrossEntropyLoss already averages over the batch, which under-reported
    # it by a factor of the batch size.  Both metrics are now per-sample
    # averages over the whole test set.
    test_loss, test_accury = _evaluate(module, test_data, loss_fun, device)
    print('epoch :{} test_loss:{:.4f} test_accury:{:.4f} '.format(epoch, test_loss, test_accury))
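# 基于卷积神经网络的手写数字识别 (Handwritten digit recognition with a convolutional neural network)
# 最新推荐文章于 2023-11-29 14:17:31 发布 (web-page footer: "latest recommended article published 2023-11-29 14:17:31")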