# Saving the model (模型的保存)
import numpy as np
import torch
from torch import nn, optim
from torch. autograd import Variable
from torchvision import datasets, transforms
from torch. utils. data import DataLoader
# Download MNIST and wrap the train/test splits in shuffled DataLoaders,
# then peek at one batch to confirm the tensor shapes.
train_data = datasets.MNIST(root="./", train=True,
                            transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root="./", train=False,
                           transform=transforms.ToTensor(), download=True)

batch_size = 64
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

# One batch of images is (64, 1, 28, 28); the labels are (64,).
inputs, labels = next(iter(train_loader))
print(inputs.shape)
print(labels.shape)
class LSTM(nn.Module):
    """Single-layer LSTM classifier for MNIST.

    Treats each 28x28 image as a sequence of 28 rows (28 features per
    time step) and classifies from the final hidden state.
    """

    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = torch.nn.LSTM(
            input_size=28,     # one image row per time step
            hidden_size=64,
            num_layers=1,
            batch_first=True,  # inputs are (batch, seq, feature)
        )
        self.out = torch.nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        # (batch, 1, 28, 28) -> (batch, 28, 28): sequence of 28 rows.
        x = x.view(-1, 28, 28)
        output, (h_n, c_n) = self.lstm(x)
        # h_n[-1] is the final hidden state of the last layer: (batch, 64).
        last_hidden = h_n[-1, :, :]
        # BUGFIX: return raw logits instead of softmax probabilities.
        # nn.CrossEntropyLoss (used below) applies log_softmax internally,
        # so feeding it softmax output double-squashes the scores and
        # weakens the gradients. argmax is unchanged, so the accuracy
        # evaluation in test() behaves exactly as before, and the
        # state_dict is unchanged (Softmax had no parameters).
        return self.out(last_hidden)
# Model, loss, and optimizer.
# NOTE(review): the name `mse_loss` is misleading -- it is cross-entropy,
# not mean-squared error. It is kept because train() refers to it by name.
model = LSTM()
mse_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
def train():
    """Run one epoch of training over train_loader."""
    model.train()
    for inputs, labels in train_loader:
        predictions = model(inputs)
        loss = mse_loss(predictions, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def _accuracy(loader, dataset_size):
    """Return the fraction of samples in `loader` the model classifies correctly."""
    correct = 0
    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        for inputs, labels in loader:
            out = model(inputs)
            _, predicted = torch.max(out, 1)
            correct += (predicted == labels).sum()
    return correct.item() / dataset_size


def test():
    """Print accuracy on the test set and then on the training set.

    The original duplicated the evaluation loop verbatim for each split;
    the shared logic now lives in _accuracy().
    """
    model.eval()
    print("Test acc:{0}".format(_accuracy(test_loader, len(test_data))))
    print("Train acc:{0}".format(_accuracy(train_loader, len(train_data))))
# Train for 5 epochs, evaluating after each, then persist only the learned
# parameters (state_dict) rather than pickling the whole model object.
for epoch in range(5):
    print("epoch:", epoch)
    train()
    test()

torch.save(model.state_dict(), "./my_model.pth")
# Loading the model (模型加载)
import numpy as np
import torch
from torch import nn, optim
from torch. autograd import Variable
from torchvision import datasets, transforms
from torch. utils. data import DataLoader
# Same MNIST setup as the saving script above: download both splits,
# build shuffled loaders, and print the shape of a single batch.
train_data = datasets.MNIST(root="./", train=True,
                            transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root="./", train=False,
                           transform=transforms.ToTensor(), download=True)

batch_size = 64
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

for i, data in enumerate(train_loader):
    inputs, labels = data
    print(inputs.shape)   # (64, 1, 28, 28)
    print(labels.shape)   # (64,)
    break
class Net(nn.Module):
    """Minimal MNIST classifier: flatten -> one linear layer -> softmax."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 10)     # 28*28 pixels -> 10 digit classes
        self.softmax = nn.Softmax(dim=1)  # rows sum to 1 (class probabilities)

    def forward(self, x):
        batch = x.size()[0]
        flat = x.view(batch, -1)          # (batch, 1, 28, 28) -> (batch, 784)
        logits = self.fc1(flat)
        return self.softmax(logits)
model = Net()
# NOTE(review): the checkpoint written by the saving script above comes
# from the LSTM class, whose state_dict keys (lstm.*, out.*) do not match
# Net's (fc1.*) -- load_state_dict will raise a RuntimeError on that file.
# The architecture used for loading must match the one used for saving;
# confirm which checkpoint this example is meant to load.
# map_location="cpu" lets the load succeed on a machine without the
# device the checkpoint was saved from.
model.load_state_dict(torch.load("./my_model.pth", map_location="cpu"))
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
def train():
    """Run one training epoch using MSE against one-hot encoded targets."""
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        out = model(inputs)
        # Build one-hot targets: a 1 in each row at the label's column.
        targets = torch.zeros(inputs.shape[0], 10).scatter(1, labels.reshape(-1, 1), 1)
        loss = mse_loss(out, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def _accuracy(loader, dataset_size):
    """Return the fraction of samples in `loader` the model classifies correctly."""
    correct = 0
    # Evaluation needs no gradients; no_grad saves time and memory.
    with torch.no_grad():
        for inputs, labels in loader:
            out = model(inputs)
            _, predicted = torch.max(out, 1)
            correct += (predicted == labels).sum()
    return correct.item() / dataset_size


def test():
    """Print accuracy on the test set and then on the training set.

    The duplicated per-split evaluation loop is factored into _accuracy().
    model.eval() is added for consistency with the saving script's test();
    Net has no dropout/batch-norm, so the results are unchanged.
    """
    model.eval()
    print("Test acc:{0}".format(_accuracy(test_loader, len(test_data))))
    print("Train acc:{0}".format(_accuracy(train_loader, len(train_data))))


test()