# Import packages and define the model (导入包,定义模型)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class Net(nn.Module):
    """LeNet-style CNN that maps 1-channel 28x28 images to 10 class log-probabilities.

    Layers: conv(1->20, 5x5) -> ReLU -> maxpool 2x2 -> conv(20->50, 5x5)
    -> ReLU -> maxpool 2x2 -> fc(800->500) -> ReLU -> fc(500->10) -> log_softmax.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # 50 channels * 4 * 4 spatial positions remain after the two conv/pool stages.
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return per-class log-probabilities.

        Args:
            x: input batch; the layer sizes above assume shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10) with log-softmax scores over dim 1.
        """
        # spatial size: 28 x 28
        x = F.relu(self.conv1(x))    # -> 20 x 24 x 24
        x = F.max_pool2d(x, 2, 2)    # -> 20 x 12 x 12 (2D pooling over height and width)
        x = F.relu(self.conv2(x))    # -> 50 x 8 x 8
        x = F.max_pool2d(x, 2, 2)    # -> 50 x 4 x 4
        # Flatten everything except the batch dimension. Unlike view(-1, 800),
        # this never folds features into the batch axis: an input of the wrong
        # spatial size fails loudly at fc1 instead of silently changing the
        # number of rows in the output.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
# Data loaders
# Use the GPU when available; pinned host memory is only requested for CUDA.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else "cpu")
batch_size = 32
# DataLoader's keyword is `num_workers` (plural); a misspelled key raises TypeError.
kwargs = {'num_workers': 0, "pin_memory": True} if use_cuda else {}
# Shared preprocessing for both splits: to tensor, then normalize with
# mean 0.1307 / std 0.3081 (presumably the MNIST training-set statistics).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist_data', train=True, download=True,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=True, **kwargs
)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist_data', train=False, download=True,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=False, **kwargs
)
# Optimization and logging hyperparameters used by the training run below.
lr = 0.01            # SGD learning rate
momentum = 0.5       # SGD momentum factor
log_interval = 100   # print training progress every `log_interval` batches
# Start training (开始训练)
def train(model, dataloader, device, optimizer, epoch, log_interval):
    """Train `model` for exactly one pass over `dataloader`.

    Args:
        model: network expected to return log-probabilities, since the loss
            is computed with F.nll_loss.
        dataloader: yields (data, target) batches; `len(dataloader)` and
            `len(dataloader.dataset)` are used for progress reporting.
        device: device each batch is moved to before the forward pass.
        optimizer: optimizer that updates the model's parameters.
        epoch: epoch number, used only in the progress message. It must not
            control how many passes run: the caller passes 1, 2, ... per call.
        log_interval: print progress every `log_interval` batches.
    """
    model.train()
    for idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # Negative log-likelihood of log-softmax outputs (i.e. cross-entropy).
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if idx % log_interval == 0:
            print("epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}".format(
                epoch, idx * len(data), len(dataloader.dataset),
                100. * idx / len(dataloader), loss.item()
            ))
def test(model,dataloader,device):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for idx, (data, target) in enumerate(dataloader):
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output,target,reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(dataloader.dataset)
print('Test set: Average loss: {:.4f} ,accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(dataloader.dataset), 100.*correct/len(dataloader.dataset)
))
# Build the model on the selected device and optimize it with SGD + momentum.
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
# Number of training epochs. Epochs are numbered from 1; the number is passed
# to train() for its progress log, and the model is evaluated after each one.
epochs = 2
for i in range(1, epochs + 1):
    train(model, train_loader, device, optimizer, i, log_interval)
    test(model, test_loader, device)
save_model = True
if save_model:
    # Save only the learned parameters (state_dict), not the whole module object.
    torch.save(model.state_dict(), 'mnist_cnn.pt')