tensorflow 自定义训练——eager模式,附带数据集制作

在前边的博客中我使用了tensorflow中kears模块来搭建神经网络,然后使用集成好的方法来进行训练,如model.fit(),从最简洁、最快速的方面来实现了神经网络,但是集成好的api灵活性不够,而且训练耗时比eager模式要多,所以在这里我对自定义训练做一个简单介绍,使用的数据集来源kaggle,链接:https://www.kaggle.com/alxmamaev/flowers-recognition,如果不能下载用这个百度网盘链接:https://pan.baidu.com/s/1bg2981NUwwNZmkcz0w5WLw,提取码:fp34。完整代码在我的github,觉得我写的不错的可以给个star,谢谢。源码链接:https://github.com/JohnLeek/Tensorflow-study,文件名为day7_flowers_tf_data.py

一、数据集制作

1.1 数据集基本情况简介

数据的来源是kaggle的一个花卉分类竞赛,里边总共包含五种花卉分别是,daisy(雏菊)、dandelion(蒲公英)、rose(玫瑰)、sunflower(向日葵)、tuilp(郁金香),总共是4322张图片,但是里边每个花卉图片数量不怎么一样,有的多有的少。

1.2 数据制作

这里的数据集制作我们采用tensorflow2.0中提供的data模块,根据我自己测试的结果用tf提供的模块制作数据集是比较高效的,下一篇博客中我会介绍两到三种数据集制作方法。

首先我们导入所需要的所有包

import glob
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D,BatchNormalization,MaxPool2D,Dropout,Dense,Flatten
import numpy as np

然后将数据集拷贝到项目工程下

忽略其他文件夹(这是我平时写着玩的一些模型,还有数据集,大家看flowers文件夹就好了)

文件路径

path = './flowers/*/*.jpg'

使用glob包来获得所有图片路径

all_image_path = glob.glob(path)

部分文件路径展示(这里为了方便展示用的jupyter,但是项目用的是pycharm)

到这里文件路径就获取完毕接下来就是将花卉名称转化为标签,我这里的对应规则是这样的daisy(雏菊)->0、dandelion(蒲公英)->1、rose(玫瑰)->2、sunflower(向日葵)->3、tuilp(郁金香)->4。建立一个列表来存储。

all_image_label = []
for p in all_image_path:
    if p.split('\\')[1] == 'daisy':
        all_image_label.append(0)
    if p.split('\\')[1] == 'dandelion':
        all_image_label.append(1)
    if p.split('\\')[1] == 'rose':
        all_image_label.append(2)
    if p.split('\\')[1] == 'sunflower':
        all_image_label.append(3)
    if p.split('\\')[1] == 'tulip':
        all_image_label.append(4)

好了到这里标签页准备好了,接下来就是对现有的数据进行乱序然后对训练集和测试集进行划分,我这里使用的比例是训练集:测试集=8:2。

数据集乱序

np.random.seed(5000)
np.random.shuffle(all_image_path) 
np.random.seed(5000) 
np.random.shuffle(all_image_label)

数据集的划分

image_count = len(all_image_path)
flag = int(len(all_image_path)*0.8)

train_image_path = all_image_path[:flag]
test_image_path = all_image_path[-(image_count-flag):]

train_image_label = all_image_label[:flag]
test_image_label = all_image_label[-(image_count-flag):]

编写图片加载预处理函数load_preprogress_image(),这里因为数据集的量不是很大,我对图片进行了,反转,亮度等方面的调节,扩大了 数据集,避免训练的时候出现过拟合的情况。

def load_preprogress_image(image_path,label):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image,channels=3)
    image = tf.image.resize(image,[100,100])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, 0.5)
    image = tf.image.random_contrast(image, 0, 1)
    image = tf.cast(image,tf.float32)
    image = image / 255.
    label = tf.reshape(label,[1]

设置CPU线程(为了加快生成数据集),以及batch

Batch_Size = 32 AUTOTUNE = tf.data.experimental.AUTOTUNE#tf会根据你的cpu支持的线程数量来加速数据集生成过程

生成训练集

train_image_dataset = tf.data.Dataset.from_tensor_slices((train_image_path,train_image_label))
train_image_dataset = train_image_dataset.map(load_preprogress_image,num_parallel_calls=AUTOTUNE)
train_image_dataset = train_image_dataset.shuffle(flag).batch(Batch_Size)
train_image_dataset = train_image_dataset.prefetch(AUTOTUNE)

生成测试集

test_image_dataset = tf.data.Dataset.from_tensor_slices((test_image_path,test_image_label))
test_image_dataset = test_image_dataset.map(load_preprogress_image,num_parallel_calls=AUTOTUNE)
test_image_dataset = test_image_dataset.shuffle(image_count-flag).batch(Batch_Size)
test_image_dataset = test_image_dataset.prefetch(AUTOTUNE)

二、神经网络搭建

这里使用的Sequential来搭建CNN,因为不是重点我就直接跳过了

model = Sequential([
    Conv2D(16,(3,3),padding="same",activation="relu"),
    BatchNormalization(),
    MaxPool2D(2,2),
    Dropout(0.2),

    Conv2D(32,(3,3),padding="same",activation="relu"),
    BatchNormalization(),
    MaxPool2D(2,2),
    Dropout(0.2),

    Conv2D(64, (3, 3), padding="same", activation="relu"),
    BatchNormalization(),
    MaxPool2D(2, 2),
    Dropout(0.2),

    Conv2D(128,(3,3),padding="same",activation="relu"),
    BatchNormalization(),
    MaxPool2D(2,2),
    Dropout(0.2),

    Conv2D(256, (3, 3), padding="same", activation="relu"),
    BatchNormalization(),
    MaxPool2D(2, 2),
    Dropout(0.2),

    Flatten(),
    Dense(512,activation="relu"),
    Dropout(0.2),
    Dense(128, activation="relu"),
    Dropout(0.2),
    Dense(5,activation="softmax")
])

三、神经网络训练

重点来了,以前我是使用keras中的fit函数来实现对模型的训练,今天我们采用自定义训练,有很多地方是很不一样的需要注意下,要是学会了tf的自定义训练,其实看pytorch也就很轻松了,至少对我来说是这样的。

1、定义优化器和损失函数

损失函数我使用,交叉熵损失函数,这里因为标签是0,1,2,3,4,不是独热码所以选择SparseCategoricalCrossentropy,优化器使用Adam,参数全都默认。

loss_object = tf.keras.losses.SparseCategoricalCrossentropy() 
optimizer = tf.keras.optimizers.Adam()

2、设置损失函数还有准确率计算方法

这里使用了keras模块中的Mean方法,来统计每个batch结束后损失函数的均值,用SparseCategoricalAccuracy,来确定正确率。

train_loss = tf.keras.metrics.Mean('train_loss')
train_accracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean('test_loss')
test_accracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

3、训练函数

可以看到这里用了@tf.function这个注解,根据官网给出的文档,使用该注解可以将下边声明的函数转化为图节点加快训练过程。我自己做了下对比,使用@tf.function,确实要比不使用要快,比使用keras中fit方法训练也快很多。

手动训练的过程需要我们自己过前向传播还有计算损失函数,训练过程,前向传播过程使用梯度下降根据我们模型的预测prediction,和原有数据标签labels进行计算和更新,梯度下降中的参数为loss,和网络可训练参数model.trainable_variables,最后送至优化器,优化器参数为梯度gradients 和网络可训练参数model.trainable_variables。最后将一个batch中的loss和accracy进行计算(train_loss,train_accracy)

@tf.function
def train_step(images,labels):
    with tf.GradientTape() as tape:
       prediction = model(images)
       loss = loss_object(labels,prediction)
    gradients = tape.gradient(loss,model.trainable_variables)
    optimizer.apply_gradients(zip(gradients,model.trainable_variables))
    train_loss(loss)
    train_accracy(labels,prediction)

4、测试函数

测试部分我们只需要过前向传播,根据训练好的模型,来判断网络拟合能力,并不做参数更新。

@tf.function
def test_step(images,lables):
    prediction = model(images)
    t_loss = loss_object(lables,prediction)

    test_loss(t_loss)
    test_accracy(lables,prediction)

5、main()函数

main()函数里边要注意的第一个点就是在每个循环开始我们要将梯度值还有准确率归零。

def main():
    EPOCHS = 50
    
    for epoch in range(EPOCHS):
        # 下个循环开始时指标归零
        train_loss.reset_states()
        train_accracy.reset_states()
        test_loss.reset_states()
        test_accracy.reset_states()
        for images, labels in train_image_dataset:
            train_step(images, labels)

        for test_images, test_labels in test_image_dataset:
            test_step(test_images, test_labels)
          
        template = 'Epoch:{},Loss:{:.4f},Accuracy:{:.4f},Test Loss:{:.4f},Test Accuracy:{:.4f}'
        print(template.format(epoch + 1, train_loss.result(), train_accracy.result() * 100, test_loss.result(),
                              test_accracy.result() * 100))

6、调用

if __name__ == '__main__':     main()

四、结果展示

Epoch:1,Loss:1.4798,Accuracy:33.3237,Test Loss:1.3483,Test Accuracy:38.4971
Epoch:2,Loss:1.3679,Accuracy:38.8776,Test Loss:1.3610,Test Accuracy:41.2717
Epoch:3,Loss:1.3311,Accuracy:40.5554,Test Loss:1.2352,Test Accuracy:45.6647
Epoch:4,Loss:1.2949,Accuracy:43.6795,Test Loss:1.3696,Test Accuracy:38.9595
Epoch:5,Loss:1.2965,Accuracy:45.2994,Test Loss:1.1869,Test Accuracy:47.3988
Epoch:6,Loss:1.2407,Accuracy:47.0929,Test Loss:1.1343,Test Accuracy:53.8728
Epoch:7,Loss:1.1775,Accuracy:50.9690,Test Loss:1.1222,Test Accuracy:52.7168
Epoch:8,Loss:1.1375,Accuracy:53.1096,Test Loss:1.1582,Test Accuracy:56.1850
Epoch:9,Loss:1.1055,Accuracy:55.2502,Test Loss:1.1028,Test Accuracy:53.2948
Epoch:10,Loss:1.0280,Accuracy:58.4900,Test Loss:0.9991,Test Accuracy:58.6127
Epoch:11,Loss:1.0030,Accuracy:59.0107,Test Loss:1.0338,Test Accuracy:60.2312
Epoch:12,Loss:1.0861,Accuracy:55.1345,Test Loss:0.9566,Test Accuracy:61.9653
Epoch:13,Loss:0.9947,Accuracy:61.2381,Test Loss:1.3753,Test Accuracy:43.4682
Epoch:14,Loss:1.1058,Accuracy:55.1345,Test Loss:0.9431,Test Accuracy:60.6936
Epoch:15,Loss:0.9891,Accuracy:60.8042,Test Loss:0.9049,Test Accuracy:64.1618
Epoch:16,Loss:0.9298,Accuracy:62.5398,Test Loss:0.9950,Test Accuracy:59.6532
Epoch:17,Loss:0.8813,Accuracy:64.3043,Test Loss:0.9478,Test Accuracy:61.6185
Epoch:18,Loss:0.8744,Accuracy:64.9118,Test Loss:0.9666,Test Accuracy:63.1214
Epoch:19,Loss:0.8952,Accuracy:64.0729,Test Loss:0.9820,Test Accuracy:62.5434
Epoch:20,Loss:0.9128,Accuracy:63.5811,Test Loss:0.9640,Test Accuracy:62.4277
Epoch:21,Loss:0.8939,Accuracy:64.0729,Test Loss:0.8645,Test Accuracy:66.7052
Epoch:22,Loss:0.8058,Accuracy:68.3541,Test Loss:0.8383,Test Accuracy:66.7052
Epoch:23,Loss:0.8182,Accuracy:68.0648,Test Loss:0.9324,Test Accuracy:63.4682
Epoch:24,Loss:0.8720,Accuracy:65.0853,Test Loss:0.9143,Test Accuracy:66.3584
Epoch:25,Loss:0.7710,Accuracy:68.7880,Test Loss:0.8930,Test Accuracy:65.6647
Epoch:26,Loss:0.7632,Accuracy:69.5979,Test Loss:0.9349,Test Accuracy:67.5145
Epoch:27,Loss:0.7457,Accuracy:70.0897,Test Loss:0.9430,Test Accuracy:63.5838
Epoch:28,Loss:0.7242,Accuracy:70.6682,Test Loss:0.8867,Test Accuracy:66.1272
Epoch:29,Loss:0.7343,Accuracy:71.5649,Test Loss:0.8693,Test Accuracy:65.6647
Epoch:30,Loss:0.7664,Accuracy:69.6847,Test Loss:0.9453,Test Accuracy:62.1965
Epoch:31,Loss:0.6764,Accuracy:72.9534,Test Loss:0.9475,Test Accuracy:65.0867
Epoch:32,Loss:0.6725,Accuracy:72.6352,Test Loss:0.8022,Test Accuracy:68.5549
Epoch:33,Loss:0.6440,Accuracy:75.0362,Test Loss:0.9749,Test Accuracy:64.9711
Epoch:34,Loss:0.6510,Accuracy:74.4287,Test Loss:0.9438,Test Accuracy:66.8208
Epoch:35,Loss:0.6480,Accuracy:73.9369,Test Loss:0.8713,Test Accuracy:66.2428
Epoch:36,Loss:0.6159,Accuracy:76.0197,Test Loss:0.9768,Test Accuracy:66.3584
Epoch:37,Loss:0.5898,Accuracy:76.9164,Test Loss:0.9000,Test Accuracy:66.3584
Epoch:38,Loss:0.5596,Accuracy:78.1313,Test Loss:0.8711,Test Accuracy:70.0578
Epoch:39,Loss:0.5482,Accuracy:78.3917,Test Loss:0.9857,Test Accuracy:66.3584
Epoch:40,Loss:0.5087,Accuracy:79.9248,Test Loss:1.0123,Test Accuracy:70.5202
Epoch:41,Loss:0.5826,Accuracy:77.0610,Test Loss:0.9282,Test Accuracy:68.7861
Epoch:42,Loss:0.4956,Accuracy:80.3876,Test Loss:1.1306,Test Accuracy:68.3237
Epoch:43,Loss:0.5022,Accuracy:80.9083,Test Loss:0.9783,Test Accuracy:68.3237
Epoch:44,Loss:0.5015,Accuracy:80.6190,Test Loss:1.1273,Test Accuracy:68.5549
Epoch:45,Loss:0.4816,Accuracy:81.4001,Test Loss:0.9927,Test Accuracy:69.7110
Epoch:46,Loss:0.4301,Accuracy:83.0778,Test Loss:0.9987,Test Accuracy:67.6301
Epoch:47,Loss:0.4640,Accuracy:82.5282,Test Loss:1.1025,Test Accuracy:64.9711
Epoch:48,Loss:0.5724,Accuracy:78.5074,Test Loss:1.1148,Test Accuracy:65.8960
Epoch:49,Loss:0.4874,Accuracy:81.8629,Test Loss:1.1687,Test Accuracy:68.2081
Epoch:50,Loss:0.3843,Accuracy:85.2473,Test Loss:1.0302,Test Accuracy:69.8266

可以看到在50个epoch之后训练集准确率到达了85%,但是测试集只有69%,网络出现了过拟合,原因:第一是数据集太小,而且数据分布不是很合理,有的图片多有的图片少,第二,是我网络结构不合理,虽然里边用的批标准化和随机舍弃神经元操作,但是因为参数太多,导致了过拟合,大家可以自己修改网络参数,好了博客到这里就结束了,觉得不错的可以点个赞,源码在github(https://github.com/JohnLeek/Tensorflow-study),可以给个start,谢谢!

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值