Compile & Fit


写出来是仅供自己观看的,记一些笔记,大佬请绕道,跳过即可

标准流程

通常,在神经网络的学习中,我们要对数据进行training的时候,有一套最基础的标准的流程进行设置。

  1. 设定epoch;
  2. 输入训练集(db);
  3. 设置loss函数;
  4. 计算更新梯度(优化器);
  5. 测试;
# 标准的training流程(手写数字)
for epoch in range (50):
	for step, (x, y) in enumerate(db):
		# [b, 28, 28] => [b, 784]
		x = tf.reshape(x, (-1, 28*28))
		
		with tf.GradientTape() as tape:
			# [b, 784] => [b, 10](神经网络输出)
			out = network(x)
			# [b] => [b, 10]
			y_onehot = tf.onehot(y, depth=10)
			# [b](用标准loss函数计算,或自己定义loss函数)
			loss = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, out, from_logits=Ture))
		
		# compute the gradient(计算梯度)
		grads = tape.gradient(loss, network.trainable_variables)
		# update grad parameter(更新梯度)
		optimizer.apply_gradients(zip(grads, network.trainable_variables))
	
	total_correct = 0
	total_num = 0
	for x, y in db_test:
		# [b, 28, 28] => [b, 784]
		x = tf.reshape(x, (-1, 28*28))
		out = network(x)
		# out => prob(用softmax转换成(0,1),得出可能性,argmax找出位置即第几个,与手写数字对应)
		prob = tf.nn.softmax(out, axis=1)
		pred = tf.argmax(prob, axis=1)
		# pred:[b]
		# y:[b]
		# Ture: equal, False: not equal(判断正误)
		correct = tf.equal(pred, y)
		correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))
		total_correct += int(correct)
		total_num += x.shape[0]
	acc = total_correct / total_num

由于训练流程大多是相近的,因此,keras提供了一个API,通过compile & fit 来完成标准化流程的简单写法。

compile & fit

通过compile & fit各参数的设定,使用简单的方法完成之前复杂的步骤,具体参数设定如下列代码所示:

# 优化器选择Adam, lr=0.01
# 选择loss函数CategoricalCrossentropy
network.compile(optimizer=optomizers.Adam(lr=0.01),
				loss=tf.losses.CategoricalCrossentropy(from_logits=Ture)
				)
				
# 设置epoch, db
# 设置测试(验证)集db_test
# 循环2次db,进行一次val
network.fit(db, epoch, validation_data=db_test, validation_freq=2)

通过训练后,将出现两个loss和两个acc,loss和accuracy是训练集的loss和准确率;val_loss和val_accuracy是测试集的loss和acc。
训练结果
写出来是仅供自己观看的,记一些笔记,大佬请绕道,跳过即可

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值