# 1. Import libraries (一、导入包库)
import torch
import torchvision as tv
from torch import nn
from torch. utils. data import DataLoader
import torchvision. transforms as T
import torch. nn. functional as F
# 2. Hyperparameters (二、超参数)
# Number of samples per mini-batch, shared by the train and test loaders.
BATCH_SIZE = 128
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# 3. Data preprocessing (三、数据预处理)
# ToTensor converts each PIL image into a float tensor with values scaled to [0, 1].
transform = T.ToTensor()

# download=False: the MNIST files must already exist under ./data.
train_data = tv.datasets.MNIST(
    root='./data',
    train=True,
    download=False,
    transform=transform,
)
# Reshuffle the training set every epoch so the optimizer does not see the
# mini-batches in the same fixed order each time.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# train=False selects the held-out MNIST test split; using train=True here
# would report accuracy on the very images the model was trained on.
test_data = tv.datasets.MNIST(
    root='./data',
    train=False,
    download=False,
    transform=transform,
)
# Evaluation order does not affect the metrics, so keep it deterministic.
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
# 4. Build the network (四、构建网络)
class Net(nn.Module):
    """LeNet-style CNN: two conv/pool stages followed by two fully connected layers.

    The layer sizes assume 1-channel 28x28 inputs: 28 -conv5-> 24 -pool-> 12
    -conv5-> 8 -pool-> 4, giving a 50x4x4 feature map (800 values) for fc1.
    ``forward`` returns log-probabilities over 10 classes (``log_softmax`` on
    dim 1), so the matching training criterion is ``nn.NLLLoss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)   # 1 -> 20 channels, 5x5 kernel, stride 1
        self.conv2 = nn.Conv2d(20, 50, 5, 1)  # 20 -> 50 channels, 5x5 kernel, stride 1
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Map images of shape (N, 1, 28, 28) to log-probabilities of shape (N, 10).

        Raises:
            ValueError: if the input size does not yield the 50x4x4 feature map
                that fc1 was sized for (i.e. the images are not 28x28).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        # view(-1, 800) would silently regroup values across samples for any
        # feature map whose total size happens to be a multiple of 800, so
        # reject mismatched input sizes explicitly.
        if x.shape[-3:] != (50, 4, 4):
            raise ValueError(
                f"expected 28x28 inputs (50x4x4 feature map), "
                f"got feature map {tuple(x.shape[-3:])}"
            )
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
# 5. Training and evaluation (五、训练)
# Build the model and move its parameters to the selected device.
model = Net().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Net.forward already returns log-probabilities (log_softmax), so NLLLoss is the
# matching criterion; CrossEntropyLoss would apply log_softmax a second time.
loss_func = nn.NLLLoss()
EPOCHS = 10
LOG_EVERY = 100  # print progress every LOG_EVERY batches and on the last batch

for epoch in range(EPOCHS):
    # --- Training pass: one optimizer step per mini-batch. ---
    model.train()
    for index, (x, label) in enumerate(train_loader):
        x, label = x.to(DEVICE), label.to(DEVICE)
        out = model(x)
        loss = loss_func(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (index + 1) % LOG_EVERY == 0 or (index + 1) == len(train_loader):
            print('TRAIN', 'epoch', epoch, 'batch index', index + 1, 'loss', loss.item())

    # --- Evaluation pass: no parameter updates, so skip autograd bookkeeping. ---
    model.eval()
    correct = count = 0
    with torch.no_grad():
        for index, (x, label) in enumerate(test_loader):
            x, label = x.to(DEVICE), label.to(DEVICE)
            out = model(x)
            loss = loss_func(out, label)
            predicted = out.argmax(dim=1)  # most likely class per sample
            count += x.shape[0]
            # .item() keeps the running total a plain Python int (not a device tensor).
            correct += (predicted == label).sum().item()
            if (index + 1) % LOG_EVERY == 0 or (index + 1) == len(test_loader):
                print('TEST', 'epoch', epoch, 'batch index', index + 1,
                      'loss', loss.item(), 'acc', correct / count)