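Training CIFAR10 with fastai ResNet50: 85% Accuracy

Copyright notice: this is an original article by the blogger. Reposting is welcome; please credit the source. Contact: 460356155@qq.com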

 

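fastai is a deep learning framework built on top of PyTorch that delivers strong results. The process of training on CIFAR10 is shown below.

Import libraries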

from fastai import *
from fastai.vision import *
from fastai.callbacks import CSVLogger, SaveModelCallback

Compute and display results on the validation set

def show_result(learn):
    # Get predictions and labels for the validation set, then report the metrics
    probs, val_labels = learn.get_preds(ds_type=DatasetType.Valid)
    print('Accuracy', accuracy(probs, val_labels))
    print('Error Rate', error_rate(probs, val_labels))
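
To make the two metrics concrete, here is a minimal pure-Python sketch of what accuracy and error rate measure. The helper name accuracy_and_error and the toy data are made up for illustration; fastai's own accuracy and error_rate work on tensors rather than lists.

```python
def accuracy_and_error(probs, labels):
    """Return (accuracy, error_rate) for per-class scores and integer labels."""
    correct = 0
    for row, label in zip(probs, labels):
        # Predicted class = index of the highest score in the row
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    acc = correct / len(labels)
    return acc, 1.0 - acc


# Toy batch: 4 samples, 3 classes (values are made up for illustration)
toy_probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
]
toy_labels = [0, 1, 2, 1]
acc, err = accuracy_and_error(toy_probs, toy_labels)
print(acc, err)
```

The error rate is simply one minus the accuracy, which is why the two columns in the training tables always sum to 1.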

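Display the confusion matrix and the most frequently confused classes

def show_matrix(learn):
    # Plot the confusion matrix of the training results
    interp = ClassificationInterpretation.from_learner(learn)
    interp.confusion_matrix()
    interp.plot_confusion_matrix(dpi=120)

    # Show the most frequently confused classes; min_val sets the minimum error count (default 1)
    # Output order is actual, predicted, number of occurrences.
    interp.most_confused(min_val=5)

    # Show the 9 samples the model finds hardest to predict
    # Title order is prediction, actual, loss, probability of the actual class
    interp.plot_top_losses(9, figsize=(10, 10))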
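
As a rough illustration of what the most-confused listing contains, the sketch below counts off-diagonal cells of a confusion matrix in plain Python. It is not fastai's implementation; the function and the toy labels are hypothetical, and it assumes a cell is kept when its count is at least min_val.

```python
from collections import Counter


def most_confused(actual, predicted, min_val=1):
    """List (actual, predicted, count) for off-diagonal cells with count >= min_val."""
    cells = Counter(zip(actual, predicted))
    pairs = [(a, p, n) for (a, p), n in cells.items() if a != p and n >= min_val]
    # Largest error counts first
    return sorted(pairs, key=lambda t: t[2], reverse=True)


# Toy labels, made up for illustration
toy_actual = ['cat', 'cat', 'cat', 'dog', 'dog', 'bird']
toy_predicted = ['dog', 'dog', 'cat', 'cat', 'dog', 'bird']
print(most_confused(toy_actual, toy_predicted))
```

Correct predictions sit on the diagonal and are skipped, so only genuine mix-ups are listed.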

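Download the dataset. Extraction calls the Linux tar command, which fails on Windows; in that case extract the archive manually. After extraction the data sits in the directory assigned to path below:

# Download the dataset
untar_data(URLs.CIFAR)

# Training data directory
path = Path(r'C:\Users\Administrator\.fastai\data\cifar10')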
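
If extracting by hand is inconvenient, one possible workaround is Python's standard tarfile module, which does not depend on an external tar binary. This is only a sketch: the helper name extract_archive is hypothetical, and the archive name and location in the usage comment are assumptions about where the download ended up.

```python
import tarfile
from pathlib import Path


def extract_archive(archive, dest):
    """Extract a gzip-compressed tar archive into dest using only the standard library."""
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, 'r:gz') as tar:
        tar.extractall(dest)
    return dest


# Hypothetical usage; the archive name and folder are assumptions:
# extract_archive(Path.home() / '.fastai' / 'data' / 'cifar10.tgz',
#                 Path.home() / '.fastai' / 'data')
```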

Define the data and the online augmentation

# Define the online data augmentation
tfms = get_transforms(do_flip=False)

data = (ImageList.from_folder(path)  # Where to find the data? -> in path and its subfolders
        .split_by_rand_pct()  # How to split in train/valid? -> randomly, holding out 20% by default
        .label_from_folder()  # How to label? -> depending on the folder of the filenames
        .add_test_folder()  # Optionally add a test set (here default name is test)
        .transform(tfms, size=(32, 32))  # Data augmentation? -> use tfms with a size of 32x32
        .databunch(bs=128)  # Finally? -> convert to an ImageDataBunch with batch size 128
        .normalize(imagenet_stats))  # Normalize with the ImageNet mean and std
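
The last step of the pipeline normalizes each channel with the ImageNet statistics, which matches what the pretrained ResNet50 weights expect. A minimal sketch of that arithmetic for a single pixel follows; normalize_pixel is a hypothetical helper, and pixel values are assumed to be in the range 0 to 1.

```python
# ImageNet per-channel mean and std (RGB), the values behind imagenet_stats
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Normalize one RGB pixel with values in [0, 1]: (x - mean) / std per channel."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


print(normalize_pixel(IMAGENET_MEAN))      # a pixel at the mean maps to zeros
print(normalize_pixel((1.0, 1.0, 1.0)))    # a pure white pixel
```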

Inspect the data

# Show dataset information
data.classes, data.c, data
(['airplane',
  'automobile',
  'bird',
  'cat',
  'deer',
  'dog',
  'frog',
  'horse',
  'ship',
  'truck'],
 10,
 ImageDataBunch;
 
 Train: LabelList (39072 items)
 x: ImageList
 Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
 y: CategoryList
 airplane,airplane,airplane,airplane,airplane
 Path: C:\Users\Administrator\.fastai\data\cifar10;
 
 Valid: LabelList (9767 items)
 x: ImageList
 Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
 y: CategoryList
 airplane,deer,deer,deer,automobile
 Path: C:\Users\Administrator\.fastai\data\cifar10;
 
 Test: LabelList (10000 items)
 x: ImageList
 Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
 y: EmptyLabelList
 ,,,,
 Path: C:\Users\Administrator\.fastai\data\cifar10)
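
The Train and Valid sizes above follow from the random 20% split. As a quick check, the sketch below reproduces both counts, assuming the validation cut is computed as int(valid_pct * n_items); that formula is my assumption about how the split is rounded, and split_counts is a hypothetical helper.

```python
def split_counts(n_items, valid_pct=0.2):
    """Return (n_train, n_valid), assuming the validation cut is int(valid_pct * n_items)."""
    n_valid = int(valid_pct * n_items)
    return n_items - n_valid, n_valid


total = 39072 + 9767  # items reported for Train and Valid above
print(split_counts(total))
```

Because split_by_rand_pct is called without a seed here, the membership of the two sets changes from run to run even though the sizes stay the same.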

Create the learner

# Create the learner
learn = cnn_learner(data, models.resnet50, metrics=[accuracy, error_rate], callback_fns=[ShowGraph, SaveModelCallback])

Stage 1 training

# Find a good learning rate
learn.lr_find(end_lr=1)
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

# Plot the learning rate finder curve and show the suggested learning rate
learn.recorder.plot(suggestion=True)
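
The learning rate finder trains for a short run while raising the learning rate step by step and recording the loss. The sketch below shows the kind of exponential sweep involved. The end value of 1 comes from the call above; the start value of 1e-7 and the 100 iterations are assumed defaults, and lr_sweep is a hypothetical helper rather than fastai code.

```python
def lr_sweep(start_lr=1e-7, end_lr=1.0, num_it=100):
    """Exponentially spaced learning rates from start_lr to end_lr (inclusive)."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]


lrs = lr_sweep()
print(lrs[0], lrs[len(lrs) // 2], lrs[-1])
```

Exponential spacing gives every order of magnitude the same number of trial steps, which is why the resulting plot uses a logarithmic x axis.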

 

# Pick max_lr from the learning rate curve and start training
learn.fit_one_cycle(cyc_len=15, max_lr=1.78e-2)
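
For readers unfamiliar with the one-cycle policy, here is a rough sketch of how the learning rate might evolve during such a run: a warm-up to max_lr followed by a long decay. The cosine shape, the warm-up fraction of 0.3, the starting divisor of 25, and the final divisor are all assumptions about the defaults, not values taken from this post.

```python
import math


def anneal_cos(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


def one_cycle_lr(step, total_steps, max_lr, div_factor=25.0, pct_start=0.3):
    """Learning rate at a given step of a one-cycle schedule (assumed shape)."""
    warm_steps = int(total_steps * pct_start)
    if step < warm_steps:
        # Warm-up phase: max_lr / div_factor -> max_lr
        return anneal_cos(max_lr / div_factor, max_lr, step / warm_steps)
    # Annealing phase: max_lr -> a value far below the starting rate
    pct = (step - warm_steps) / (total_steps - warm_steps)
    return anneal_cos(max_lr, max_lr / (div_factor * 1e4), pct)


max_lr = 1.78e-2  # the value passed to fit_one_cycle above
schedule = [one_cycle_lr(s, 100, max_lr) for s in range(100)]
print(schedule[0], max(schedule), schedule[-1])
```

The 100 steps are only for illustration; in a real run the schedule spans every batch of all 15 epochs.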

 

epoch  train_loss  valid_loss  accuracy  error_rate  time
0      1.074162    0.882136    0.709225  0.290775    01:17
1      0.824112    0.766163    0.740453  0.259547    01:16
2      0.811090    0.938345    0.707792  0.292208    01:16
3      0.799450    0.790665    0.733797  0.266203    01:16
4      0.763200    1.364758    0.752636  0.247364    01:18
5      0.693490    0.683559    0.776902  0.223098    01:16
6      0.673621    0.611799    0.800655  0.199345    01:16
7      0.665126    0.630715    0.796150  0.203850    01:16
8      0.612187    0.874567    0.826149  0.173851    01:16
9      0.563634    0.785189    0.820723  0.179277    01:16
10     0.515540    1.286271    0.829835  0.170165    01:21
11     0.485959    0.524455    0.840688  0.159312    01:16
12     0.444417    0.759944    0.842736  0.157264    01:17
13     0.419838    0.830482    0.845500  0.154500    01:17
14     0.421095    0.550606    0.844783  0.155217    01:16
 
 
Better model found at epoch 0 with val_loss value: 0.8821364045143127.
Better model found at epoch 1 with val_loss value: 0.7661632299423218.
Better model found at epoch 5 with val_loss value: 0.6835585832595825.
Better model found at epoch 6 with val_loss value: 0.6117991805076599.
Better model found at epoch 11 with val_loss value: 0.5244545340538025.
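
The "Better model found" messages come from SaveModelCallback, which by default watches the validation loss. The bookkeeping can be sketched in a few lines of plain Python; improvement_epochs is a hypothetical helper, and the list holds the valid_loss column from the stage 1 table above.

```python
def improvement_epochs(valid_losses):
    """Epoch indices where valid_loss drops below the best value seen so far."""
    best = float('inf')
    improved = []
    for epoch, loss in enumerate(valid_losses):
        if loss < best:
            best = loss
            improved.append(epoch)
    return improved


# valid_loss column from the stage 1 table above
stage1_valid_loss = [0.882136, 0.766163, 0.938345, 0.790665, 1.364758,
                     0.683559, 0.611799, 0.630715, 0.874567, 0.785189,
                     1.286271, 0.524455, 0.759944, 0.830482, 0.550606]
print(improvement_epochs(stage1_valid_loss))
```

The last improvement is at epoch 11, whose accuracy is 0.840688. That is consistent with the best checkpoint, not the final epoch, being the model in memory once training ends.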

Training results

# Compute and display the training results
show_result(learn)

 Accuracy tensor(0.8407)

Error Rate tensor(0.1593)

 

# Save the trained model
learn.save('stg1')

 

 

Stage 2 training

 

learn.load('stg1')
learn.unfreeze()
learn.lr_find(end_lr=1)

 LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

learn.recorder.plot(suggestion=True)

learn.fit_one_cycle(15, slice(1e-6, 5e-5))
epoch  train_loss  valid_loss  accuracy  error_rate  time
0      0.444569    0.521828    0.840381  0.159619    01:26
1      0.427062    0.513434    0.840483  0.159517    01:27
2      0.430344    0.514867    0.846524  0.153476    01:23
3      0.421480    0.550527    0.845295  0.154705    01:23
4      0.410170    0.506949    0.847855  0.152145    01:23
5      0.402150    0.542091    0.849186  0.150814    01:26
6      0.387639    0.491120    0.850927  0.149073    01:27
7      0.373022    0.511580    0.852155  0.147845    01:28
8      0.375497    0.505493    0.854101  0.145899    01:28
9      0.355466    0.585425    0.852462  0.147538    01:28
10     0.355327    0.506402    0.855534  0.144466    01:28
11     0.341208    0.498502    0.855944  0.144057    01:29
12     0.347057    0.549146    0.851746  0.148254    01:28
13     0.345185    0.533962    0.852155  0.147845    01:28
14     0.334336    0.504231    0.855432  0.144568    01:29
 
 
Better model found at epoch 0 with val_loss value: 0.5218283534049988.
Better model found at epoch 1 with val_loss value: 0.5134344696998596.
Better model found at epoch 4 with val_loss value: 0.5069490671157837.
Better model found at epoch 6 with val_loss value: 0.491120308637619.
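
Stage 2 passes slice(1e-6, 5e-5) instead of a single learning rate, so that early layers are updated more gently than the new head. One way to picture this is a set of rates spread evenly on a multiplicative scale across the layer groups. The sketch assumes the model is split into 3 layer groups and that the spacing is geometric; even_mults here is an illustrative helper, not a call into fastai.

```python
def even_mults(start, stop, n):
    """n values from start to stop, evenly spaced on a multiplicative scale."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]


# Assumption: the model is split into 3 layer groups
group_lrs = even_mults(1e-6, 5e-5, 3)
print(group_lrs)
```

Under these assumptions the earliest layers train at 1e-6, the head at 5e-5, and the middle group at the geometric mean of the two.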

Training results
# Compute and display the training results
show_result(learn)
Accuracy tensor(0.8509)
Error Rate tensor(0.1491)

Save the model
learn.save('stg2')

  

# Plot the confusion matrix of the training results
interp = ClassificationInterpretation.from_learner(learn)
interp.confusion_matrix()
interp.plot_confusion_matrix(dpi=120)

 

 

Show the most frequently confused class pairs, limited to those with at least 5 errors. Output order is actual, predicted, number of occurrences.

interp.most_confused(5)
[('bird', 'frog', 86),
 ('truck', 'automobile', 71),
 ('deer', 'frog', 66),
 ('dog', 'bird', 59),
 ('airplane', 'ship', 57),
 ('bird', 'airplane', 54),
 ('dog', 'frog', 54),
 ('bird', 'deer', 53),
 ('dog', 'deer', 50),
 ('cat', 'dog', 47),
 ('deer', 'bird', 47),
 ('automobile', 'truck', 45),
 ('ship', 'airplane', 45),
 ('cat', 'frog', 44),
 ('bird', 'dog', 37),
 ('ship', 'automobile', 34),
 ('ship', 'truck', 32),
 ('airplane', 'bird', 31),
 ('deer', 'dog', 26),
 ('frog', 'bird', 25),
 ('dog', 'cat', 24),
 ('dog', 'horse', 24),
 ('airplane', 'automobile', 23),
 ('horse', 'deer', 23),
 ('airplane', 'truck', 22),
 ('airplane', 'deer', 20),
 ('frog', 'deer', 17),
 ('cat', 'deer', 16),
 ('horse', 'dog', 14),
 ('automobile', 'ship', 13),
 ('deer', 'horse', 13),
 ('truck', 'ship', 13),
 ('bird', 'ship', 12),
 ('cat', 'bird', 12),
 ('deer', 'airplane', 12),
 ('dog', 'truck', 12),
 ('truck', 'airplane', 12),
 ('frog', 'dog', 11),
 ('airplane', 'frog', 10),
 ('deer', 'ship', 10),
 ('dog', 'airplane', 9),
 ('frog', 'automobile', 8),
 ('horse', 'frog', 8),
 ('ship', 'bird', 8),
 ('cat', 'truck', 7),
 ('horse', 'airplane', 7),
 ('horse', 'bird', 7),
 ('ship', 'deer', 7),
 ('dog', 'automobile', 6),
 ('truck', 'frog', 6),
 ('automobile', 'frog', 5),
 ('bird', 'cat', 5),
 ('bird', 'truck', 5),
 ('cat', 'ship', 5),
 ('dog', 'ship', 5),
 ('frog', 'airplane', 5)]
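
The list above is directional: ('bird', 'frog') and ('frog', 'bird') are separate entries. To see which pairs of classes are hardest to tell apart overall, the two directions can be summed. This is a small post-processing sketch, not part of fastai; merge_directions is a hypothetical helper, and the sample holds a few entries copied from the output above.

```python
from collections import defaultdict


def merge_directions(confused):
    """Sum counts for (a, b) and (b, a) so each class pair appears once."""
    totals = defaultdict(int)
    for actual, predicted, count in confused:
        totals[tuple(sorted((actual, predicted)))] += count
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)


# A few entries copied from the output above
sample = [('bird', 'frog', 86), ('truck', 'automobile', 71),
          ('frog', 'bird', 25), ('automobile', 'truck', 45),
          ('airplane', 'ship', 57), ('ship', 'airplane', 45)]
print(merge_directions(sample))
```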

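The 9 hardest samples to predict

# Show the 9 samples the model finds hardest to predict
# Title order is prediction, actual, loss, probability of the actual class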
interp.plot_top_losses(9, figsize=(10, 10))
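
"Hardest" here means highest loss. Assuming the loss is cross-entropy, a sample's loss is the negative log of the probability the model assigned to its actual class, so confident wrong answers rank first. A minimal sketch with made-up probabilities (top_losses below is a hypothetical stand-in, not the fastai method):

```python
import math


def top_losses(probs, labels, k):
    """Indices and losses of the k samples with the highest cross-entropy loss."""
    losses = [-math.log(row[label]) for row, label in zip(probs, labels)]
    order = sorted(range(len(losses)), key=losses.__getitem__, reverse=True)
    return [(i, losses[i]) for i in order[:k]]


# Toy predictions for 4 samples and 2 classes (made up for illustration)
toy_probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.05, 0.95]]
toy_labels = [0, 0, 1, 0]
print(top_losses(toy_probs, toy_labels, 2))
```

The sample at index 3 ranks first because the model gave its actual class only 5% probability.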

 

Reposted from: https://www.cnblogs.com/zhengbiqing/p/10923342.html
