import torchvision
import torch
import torch. nn as nn
import hiddenlayer as hl
from torchvision import transforms
import torch. utils. data as Data
import torch. optim as optim
# Training split of MNIST. ToTensor() converts each image into a float tensor.
train_data = torchvision.datasets.MNIST(
    root="./data/MNIST",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)

# Held-out test split. `train=False` is required here: MNIST defaults to the
# training split, so omitting it makes the "test" metrics a second pass over
# the training images (symptom: test_data_size printing 60000).
test_data = torchvision.datasets.MNIST(
    root="./data/MNIST",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
# Evaluation only aggregates over the whole split, so no shuffling is needed.
test_loader = Data.DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=False,
    num_workers=0,
)

# Sizes are always derived from the datasets themselves, never hard-coded, so
# they cannot drift from what was actually loaded.
train_data_size = len(train_data)
test_data_size = len(test_data)
print("train_data_size = {}".format(train_data_size))
print("test_data_size = {}".format(test_data_size))
class MyRNN(nn.Module):
    """ReLU RNN followed by a linear layer applied to the last time step.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the RNN hidden state.
        layer_dim: Number of stacked RNN layers.
        output_dim: Number of output scores produced per sample.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        layer_dim: int,
        output_dim: int,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # batch_first=True: inputs/outputs are laid out (batch, seq_len, features).
        self.rnn = nn.RNN(
            input_dim,
            hidden_dim,
            num_layers=layer_dim,
            batch_first=True,
            nonlinearity='relu',
        )
        self.fc1 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return output scores for a batch of sequences.

        Args:
            x: Batch-first input, i.e. ``(batch, seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        # No initial hidden state is passed, so the RNN uses its default.
        # The final hidden state is not needed and is discarded.
        out, _ = self.rnn(x)
        # Only the output at the last time step feeds the classifier head.
        return self.fc1(out[:, -1, :])
# Model hyperparameters: 28 features per time step, a 128-unit hidden state,
# one recurrent layer and 10 output scores.
input_dim, hidden_dim, layer_dim, output_dim = 28, 128, 1, 10

MyRNNimc = MyRNN(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    layer_dim=layer_dim,
    output_dim=output_dim,
)

# Loss and optimizer used to fit the model.
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(params=MyRNNimc.parameters(), lr=0.0003)

# Show the layer summary of the freshly built network.
print(MyRNNimc)
# Output of print(MyRNNimc):
# MyRNN(
#   (rnn): RNN(28, 128, batch_first=True)
#   (fc1): Linear(in_features=128, out_features=10, bias=True)
# )
# Dummy all-zero input (1 sample of shape 28 x 28) used only to build the graph.
# It is defined once here so the graph build below has a bound name to use.
input_tensor = torch.zeros(size=[1, 28, 28])
hl_graph = hl.build_graph(MyRNNimc, input_tensor)
# Use a copy of the built-in "blue" theme for this graph.
hl_graph.theme = hl.graph.THEMES["blue"].copy()
# Bare expression kept from the notebook cell (presumably there to display the graph).
hl_graph
# Optimizer and loss used by the training loop.
optimizer = torch.optim.RMSprop(MyRNNimc.parameters(), lr=0.0003)
criterion = nn.CrossEntropyLoss()

# Per-epoch history (one entry appended per epoch); plotted after training.
train_loss_all = []
train_acc_all = []
test_loss_all = []
test_acc_all = []
num_epochs = 30

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))

    # ---- Training pass over train_loader ----
    MyRNNimc.train()
    train_loss = 0.0
    corrects = 0
    train_num = 0
    for b_x, b_y in train_loader:
        # Reshape every image in the batch to (28, 28): 28 steps of 28 features.
        xdata = b_x.view(-1, 28, 28)
        output = MyRNNimc(xdata)
        pred_lab = torch.argmax(output, 1)
        loss = criterion(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The criterion (default settings) averages over the batch; multiply by
        # the batch size so the epoch average weights every sample equally.
        train_loss += loss.item() * b_x.size(0)
        corrects += torch.sum(pred_lab == b_y).item()
        train_num += b_x.size(0)
    train_loss_all.append(train_loss / train_num)
    train_acc_all.append(corrects / train_num)
    print('Train Loss: {:.4f}, Train Acc: {:.4f}'.format(
        train_loss_all[-1], train_acc_all[-1]))

    # ---- Evaluation pass over test_loader ----
    MyRNNimc.eval()
    test_loss = 0.0
    corrects = 0
    test_num = 0
    # No parameters are updated here, so skip autograd bookkeeping entirely.
    with torch.no_grad():
        for b_x, b_y in test_loader:
            xdata = b_x.view(-1, 28, 28)
            output = MyRNNimc(xdata)
            pred_lab = torch.argmax(output, 1)
            loss = criterion(output, b_y)
            test_loss += loss.item() * b_x.size(0)
            corrects += torch.sum(pred_lab == b_y).item()
            test_num += b_x.size(0)
    test_loss_all.append(test_loss / test_num)
    test_acc_all.append(corrects / test_num)
    print('Test Loss: {:.4f}, Test Acc: {:.4f}'.format(
        test_loss_all[-1], test_acc_all[-1]))
# Console output recorded from one run of the training loop
# (epochs 3-27 omitted):
# Epoch 0/29
# Train Loss: 0.5452, Train Acc: 0.8335
# Test Loss: 0.4664, Test Acc: 0.8524
# Epoch 1/29
# Train Loss: 0.3912, Train Acc: 0.8826
# Test Loss: 0.3929, Test Acc: 0.8778
# Epoch 2/29
# Train Loss: 0.3026, Train Acc: 0.9103
# Test Loss: 0.3048, Test Acc: 0.9089
# ...
# Epoch 28/29
# Train Loss: 0.0497, Train Acc: 0.9850
# Test Loss: 0.0605, Test Acc: 0.9821
# Epoch 29/29
# Train Loss: 0.0495, Train Acc: 0.9850
# Test Loss: 0.0437, Test Acc: 0.9862
import matplotlib. pyplot as plt
# Two side-by-side panels: per-epoch loss on the left, accuracy on the right.
# Each panel draws the training history in red circles and the test history in
# blue squares.
plt.figure(figsize=(14, 5))
panels = (
    (1, "Loss", (train_loss_all, "Train Loss"), (test_loss_all, "Test Loss")),
    (2, "Accuracy", (train_acc_all, "Train Acc"), (test_acc_all, "Test Acc")),
)
for position, y_label, train_series, test_series in panels:
    plt.subplot(1, 2, position)
    train_values, train_label = train_series
    test_values, test_label = test_series
    plt.plot(train_values, "ro-", label=train_label)
    plt.plot(test_values, "bs-", label=test_label)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.legend()
plt.show()