多分类问题之MNIST手写数字识别
0.导入所需要的库
import torch
from torchvision import transforms
from torchvision import datasets
from torch. utils. data import DataLoader
import torch. optim as optim
import torch. nn. functional as F
1.准备数据集
# Mini-batch size shared by the training and the test loader.
batch_size = 64

# Preprocessing applied to every image: convert to a tensor, then normalise
# with mean 0.1307 and std 0.3081 (presumably the usual MNIST statistics --
# confirm if the dataset changes). Both arguments are 1-tuples, one entry for
# the single grayscale channel; the original passed std as a bare float
# because `(0.3081)` lacks the trailing comma.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split: downloaded to ./data/ on first use, reshuffled each epoch.
train_dataset = datasets.MNIST(root='./data/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

# Test split: same preprocessing, iterated in a fixed order.
test_dataset = datasets.MNIST(root='./data/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
2.设计模型
class Net(torch.nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 128 -> 64 -> 10.

    The four hidden layers are followed by ReLU; the last layer returns raw
    scores (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        widths = (784, 512, 256, 128, 64, 10)
        # Register the layers as l1..l5, in that order, so parameter names
        # and the order in which they are initialised stay the same.
        for index, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'l{index}', torch.nn.Linear(n_in, n_out))

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return 10 scores per row."""
        hidden = x.view(-1, 784)
        for layer in (self.l1, self.l2, self.l3, self.l4):
            hidden = F.relu(layer(hidden))
        # No activation on the output layer.
        return self.l5(hidden)


model = Net()
3.定义损失和优化器
# Loss applied to the model's raw output scores and the integer class labels.
criterion = torch.nn.CrossEntropyLoss()

# Stochastic gradient descent over every model parameter, learning rate 0.01.
optimizer = optim.SGD(params=model.parameters(), lr=0.01)
4.训练
def train(epoch):
    """Run one pass over ``train_loader``, updating ``model`` after each batch.

    Every 300 batches the mean loss over all batches seen so far in this
    pass is printed.

    Args:
        epoch: Epoch number; only used in the progress message.
    """
    loss_sum = 0
    for step, (inputs, target) in enumerate(train_loader):
        # Forward pass and loss for this mini-batch.
        loss = criterion(model(inputs), target)

        # Clear the gradients left from the previous step, then update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        batches_seen = step + 1
        if batches_seen % 300 == 0:
            # NOTE(review): the message text is kept exactly as in the source;
            # the recorded sample output shows it without the extra spaces.
            print(f'epoch: {epoch} ,batch_idx: {step} ,loss: {loss_sum / batches_seen} ')
5.测试
def test():
    """Evaluate ``model`` on ``test_loader`` and print the accuracy in percent.

    The predicted class of each sample is the index of its largest output
    score. Accuracy is ``100 * correct / total`` over the whole loader.
    """
    correct = 0
    total = 0
    # Inference only: disabling autograd avoids building a graph for every
    # batch and replaces the legacy ``outputs.data`` access.
    with torch.no_grad():
        for inputs, target in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    # NOTE(review): the message text is kept exactly as in the source; the
    # recorded sample output shows it without the extra spaces.
    print(f'正确率:% {100 * correct / total} ')
6.运行
if __name__ == "__main__":
    # Five training epochs, each followed by an evaluation on the test set.
    for epoch in range(5):
        train(epoch)
        test()
epoch:0,batch_idx:299,loss:2.288219335079193
epoch:0,batch_idx:599,loss:2.237070083816846
epoch:0,batch_idx:899,loss:1.9507552921772002
正确率:%78.45
epoch:1,batch_idx:299,loss:0.5997616300483545
epoch:1,batch_idx:599,loss:0.5263905554513136
epoch:1,batch_idx:899,loss:0.47859397581881946
正确率:%90.09
epoch:2,batch_idx:299,loss:0.338019950290521
epoch:2,batch_idx:599,loss:0.31978671551992494
epoch:2,batch_idx:899,loss:0.3034892007046276
正确率:%92.64
epoch:3,batch_idx:299,loss:0.24073752822975317
epoch:3,batch_idx:599,loss:0.229553019789358
epoch:3,batch_idx:899,loss:0.2237504247865743
正确率:%93.77
epoch:4,batch_idx:299,loss:0.18352980084717274
epoch:4,batch_idx:599,loss:0.176467257378002
epoch:4,batch_idx:899,loss:0.17287401129802069
正确率:%95.43