20.自定义训练(非编译训练)

我们使用手写数据集mnist

目录

1  导入库

2  导入数据

3  处理数据

4  创建数据集

5  建立模型

6  定义优化器与损失函数

7  定义训练函数

7.1  看一下训练数据

7.2  看一下未训练模型的预测结果

7.3  定义损失值函数

7.4  定义训练步骤

7.5  定义训练函数

8  学习怎么定义评估参数

8.1  均值任意参数

8.2  定义正确率

9  加入评估参数

9.1  更改训练步骤函数

9.2  定义测试步骤函数

9.3  更改训练函数

10  训练




1  导入库

2  导入数据

我们看一下train_image与train_labels

和fashion_mnist是相似的



3  处理数据

图像数据升维,归一化,转换为float32类型

标签数据转换为int64类型

之后以相同的方法处理测试数据

  • 使用model.fit的时候需要repeat(),自定义训练不需要repeat()

4  创建数据集

之后对dataset进行乱序并设置批次

  • 这里我们没有设置repeat(),默认数据集只走一次不会重复

之后再创建测试集

5  建立模型

我们这里可以这样写input_shape,这样写无论图像的尺寸如何都可以让其参与训练,当然如果图像大小不一这样的效果并不好,我们当前使用的mnist的图像大小都为(28,28)



6  定义优化器与损失函数

我们自定义训练是没有编译过程的,所以我们再外部定义好优化器和损失函数

首先定义优化器

之后定义损失函数

我们的损失函数使用的是SparseCategoricalCrossentropy(),这个大写开头的损失函数不需要加参数,如果使用 sparse_categorical_crossentropy 则需要加参数,此处我们使用大写的



7  定义训练函数



7.1  看一下训练数据

dataset在eager模式下是可迭代的,我们使用iter将其图像与标签分离,然后使用next看一下它第一批的图像与标签

  • feature

  • label

7.2  看一下未训练模型的预测结果

我们在未训练的时候先看一下直接用模型预测出的结果

  • 这里使用model(feature)也是对这一批的图像进行预测,但是略有区别,我们后面都使用model()

这一步我们会得到一个大列表,其中有三十二个小列表,列表中有十个元素分别对应十个标签的概率值

由于我们这次是输入了32张图像,所以我们取第一维度(axis=1)最大的值的索引

现在我们再看一下实际的情况

可以说是非常不准的



7.3  定义损失值函数

其中y_是我们的预测结果,之后return,使用我们刚刚定义的loss_func(交叉熵),传入两个参数,第一个参数为真实值,第二个参数为预测值,这两个参数不能写反



7.4  定义训练步骤

传入三个参数model(模型),images(图),labels(标签),之后记录梯度(loss值对模型中可训练参数(model.trainable_variables)的微分,之后使用优化器(我们刚刚定义的)中的apply_gradients(适应梯度)方法,参数为梯度与要改变模型中参数的集合



7.5  定义训练函数

实际上就是把上面的循环步骤放到一个循环中,我们在这里训练10个epoch,给dataset放一个标号命名为batch_num(这个标号有没有都行),当一个批次结束的时候会有一个提示,当一个epoch结束的时候也会有个提示

现在我们还没有加入评估的参数(acc与loss),所以先不进行训练,下面我们看一下怎么加入评估参数



8  学习怎么定义评估参数



8.1  均值任意参数

我们首先创建一个对象

  • Mean是求均值的意思,后面的sth可以是任意字符串,是我们训练时定义的,需要什么就定义什么

这个对象在调用的时候每次会返回传入参数的均值,我们现在传入一个整形数据10

发现当前值就是10

之后我们再传入一个整形数据20

发现返回的值是15(20+10的平均值)

我们可以使用result()来返回当前的结果

我们再传入个列表试一下

发现返回的值是25(10+20+30+40的均值)

我们如果只想使用我们当前的数据,我们需要将m这个对象重置,我们使用方法reset_states()

那么我们传入一个整形数据5试一下

发现它不再与之前的数值进行计算了



8.2  定义正确率

首先先使用tf.keras.metrics.SparseCategoricalAccuracy()创建对象

  • 参数中的字符串依然是随意,我们这里就使用acc

然后我们计算一下上面next出来预测的正确率

  • label与feature都是上面next(iter(dataset))出来的结果

发现正确率很低,这个也确实是我们的实际情况

9  加入评估参数

首先我们定义四个评估对象

9.1  更改训练步骤函数

之后我们更改训练步骤函数

9.2  定义测试步骤函数

9.3  更改训练函数

10  训练

发现可以正常训练,也能显示一个epoch的loss,acc,val_acc与val_loss,在这里我们现不绘制图表,再后面的自定义训练实例中,我们演示图表的绘制方法。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Suyuoa

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值