# 一:导入所用到的模块
import torch
import torch. nn as nn
import torch. nn. functional as F
import torchvision
import torch. optim as optim
import torchvision. transforms as transforms
import matplotlib. pyplot as plt
import numpy as np
# 二:定义LeNet模型
class LeNet(nn.Module):
    """LeNet-style CNN producing 10 class logits from 3-channel images.

    Layout: conv(3->16, 5x5) -> ReLU -> maxpool(2) -> conv(16->32, 5x5) -> ReLU
    -> maxpool(2) -> fc(800->120) -> ReLU -> fc(120->84) -> ReLU -> fc(84->10).
    ``fc1`` requires the pooled feature map to be 32 x 5 x 5, which is what a
    32 x 32 input produces.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)    # 32x32 -> 28x28
        self.pool1 = nn.MaxPool2d(2, 2)     # 28x28 -> 14x14
        self.conv2 = nn.Conv2d(16, 32, 5)   # 14x14 -> 10x10
        self.pool2 = nn.MaxPool2d(2, 2)     # 10x10 -> 5x5
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw (un-softmaxed) logits, one row of 10 per input image.

        Raises:
            ValueError: if the input size does not reduce to a 32x5x5 feature
                map, instead of silently reshaping it into a wrong batch.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # view(-1, 800) below would happily split e.g. one 32x10x10 map into
        # 4 fake samples; reject any feature map fc1 was not built for.
        if x.shape[-3:] != (32, 5, 5):
            raise ValueError(
                f"LeNet expects 32x32 inputs; got feature map {tuple(x.shape)} after pool2"
            )
        x = x.view(-1, 32 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# 三:加载CIFAR10数据集(download=False,需事先将数据放在 ./data 目录下)
# ToTensor 将图像转换为 C * H * W 的张量并把像素值缩放到 [0, 1];Normalize 再按通道进行标准化(均值和标准差均取 0.5,结果落在 [-1, 1])
# Shared preprocessing: ToTensor yields a C x H x W float tensor, then Normalize
# maps every channel through (v - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Both CIFAR-10 splits are read from ./data; download=False means nothing is
# fetched, so the files are assumed to be there already.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=False, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=36, shuffle=True, num_workers=0
)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=False, transform=transform
)
# A single huge batch -- presumably meant to hold the whole test split at once
# (CIFAR-10's test split has 10000 images; confirm against the dataset).
testloader = torch.utils.data.DataLoader(
    testset, batch_size=10000, shuffle=False, num_workers=0
)
# iter() 将可迭代对象转换为迭代器,next() 从迭代器中取出一批测试图像及其标签(旧版本 PyTorch 中也可写作 it.next())
# Take one batch from the test loader; with its batch_size this is presumably
# the entire test split, and it is reused later as the evaluation batch.
# The builtin next() works on every Python/PyTorch version.
test_data_iter = iter(testloader)
test_image, test_label = next(test_data_iter)

# Class names, indexed by the integer label.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Shape sanity checks. Bare `x.size()` expressions only display inside a
# notebook cell, so print them explicitly when running as a script.
print(test_image[1].size())
print(test_label.size())
print(test_image.size())

train_data_iter = iter(trainloader)
train_image, train_label = next(train_data_iter)
print(train_image.size())
print(train_label.size())
# 查看前四张测试图片及其标签
def imshow(img):
    """Plot a normalized C x H x W image tensor with matplotlib.

    Reverses the (v - 0.5) / 0.5 normalization via v * 0.5 + 0.5, then moves
    the channel axis last because plt.imshow expects H x W x C.
    """
    restored = img * 0.5 + 0.5
    plt.imshow(restored.numpy().transpose(1, 2, 0))
    plt.show()
# Show the first four test images with their class names. Slice the batch to
# match: make_grid on the full test batch would tile every image in it
# (10000 with the loader settings above) into one enormous figure.
n_preview = 4
print(' '.join('%5s' % classes[test_label[j]] for j in range(n_preview)))
imshow(torchvision.utils.make_grid(test_image[:n_preview]))
# 四:实例化模型、损失函数和优化器
# Network, loss and optimizer used by the training loop.
net = LeNet()
# CrossEntropyLoss consumes the raw logits LeNet returns (fc3 has no activation).
loss_function = nn.CrossEntropyLoss()
# Adam over every network parameter, learning rate 1e-3.
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)
# 五:训练
# Train for 5 epochs. Every LOG_INTERVAL steps, report the mean training loss
# over that window and the accuracy on the held-out batch (test_image/test_label).
LOG_INTERVAL = 500  # print cadence AND loss-average divisor: keep them in sync

for epoch in range(5):
    running_loss = 0.0
    for step, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (step + 1) % LOG_INTERVAL == 0:
            # Evaluate in eval mode and restore train mode afterwards; a no-op
            # for the current layers, but correct if dropout/batch-norm is added.
            net.eval()
            with torch.no_grad():
                test_outputs = net(test_image)
                predict_y = test_outputs.argmax(dim=1)
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
            net.train()
            print('[%d,%5d] train_loss: %.3f test_accuracy: %.3f' %
                  (epoch + 1, step + 1, running_loss / LOG_INTERVAL, accuracy))
            running_loss = 0.0

print('Finished Training')
# 六:保存训练好的模型参数
# Persist only the learned tensors (state_dict), not the pickled module object.
save_path = './Lenet.pth'
torch.save(obj=net.state_dict(), f=save_path)
# 七:加载训练好的模型参数并对单张图片进行预测(推理,并非迁移学习)
from PIL import Image
# Preprocessing for one arbitrary image: resize to 32x32, then apply the same
# ToTensor + Normalize used for the training data.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Rebuild the network from the weights saved above and switch to inference mode.
pr_net = LeNet()
pr_net.load_state_dict(torch.load(save_path))
pr_net.eval()

# convert('RGB') guarantees 3 channels: an RGBA or grayscale file would
# otherwise break Normalize / conv1. The with-block closes the file handle.
with Image.open('1.jpg') as img:
    im = transform(img.convert('RGB'))
im = torch.unsqueeze(im, dim=0)  # add batch dim: (3, 32, 32) -> (1, 3, 32, 32)

with torch.no_grad():
    output = pr_net(im)
    # .item() gives a plain int; int() on a 1-element ndarray is deprecated in NumPy.
    predict = output.argmax(dim=1).item()
print(classes[predict])