BBIT
畅游人工智能之海
——Keras教程之
指标函数
![188a8bebed0528fe9feb4ee437ddf905.png](https://i-blog.csdnimg.cn/blog_migrate/ce2e7fce2b99353c777f21f466bb2f85.png)
Artificial Intelligence
Metrics函数是用于判断模型性能的函数。
Metrics函数与损失函数类似,不同之处在于训练模型时不使用指标函数的结果。我们可以使用任何损失函数作为指标。
指标函数有非常多种类,如精度指标、概率指标、回归指标等等,今天我们来整体概述一下Metrics函数。
compile()中的用法
compile方法会采用一个metrics参数,该参数是Metrics的列表:
model.compile( optimizer='adam', loss='mean_squared_error', metrics=[ metrics.MeanSquaredError(), metrics.AUC(), ])'''指标值在fit()期间显示,并记录到fit()返回的历史对象中。它们也会被model.evaluate()返回在训练期间监控指标的最佳方法是通过TensorBoard'''# 要跟踪特定名称下的指标,可以将name参数传递给度量构造函数:model.compile( optimizer='adam', loss='mean_squared_error', metrics=[ metrics.MeanSquaredError(name='my_mse'), metrics.AUC(name='my_auc'), ])#所有内置指标也可以通过其字符串标识符传递(在这种情况下,使用默认构造函数参数值,包括默认指标名称):model.compile( optimizer='adam', loss='mean_squared_error', metrics=[ 'MeanSquaredError', 'AUC', ])
独立使用
与损失不同,指标是有状态的。使用update_state()方法更新其状态,并使用result()方法查询标量指标结果:
m = tf.keras.metrics.AUC()m.update_state([0, 1, 1, 1], [0, 1, 0, 0])print('Intermediate result:', float(m.result()))m.update_state([1, 1, 1, 1], [0, 1, 1, 0])print('Final result:', float(m.result()))
以下是如何将指标用作简单自定义培训循环的一部分:
accuracy = tf.keras.metrics.CategoricalAccuracy()loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)optimizer = tf.keras.optimizers.Adam()# Iterate over the batches of a dataset.for step, (x, y) in enumerate(dataset): with tf.GradientTape() as tape: logits = model(x) # Compute the loss value for this batch. loss_value = loss_fn(y, logits) # Update the state of the `accuracy` metric. accuracy.update_state(y, logits) # Update the weights of the model to minimize the loss value. gradients = tape.gradient(loss_value, model.trainable_weights) optimizer.apply_gradients(zip(gradients, model.trainable_weights)) # Logging the current accuracy value so far. if step % 100 == 0: print('Step:', step) print('Total running accuracy so far: %.3f' % accuracy.result())
创建自定义指标
作为简单的可调用(无状态)
与loss函数非常类似,任何具有签名metric_fn(y_true,y_pred)的可调用函数返回一个损失数组(输入批处理中的一个示例)都可以作为指标传递给compile()。请注意,对于任何此类指标,都自动支持样本权重。
下面是一个简单的例子:
def my_metric_fn(y_true, y_pred): squared_difference = tf.square(y_true - y_pred) return tf.reduce_mean(squared_difference, axis=-1) # Note the `axis=-1`model.compile(optimizer='adam', loss='mean_squared_error', metrics=[my_metric_fn])
作为Metric的子类(有状态)
并不是所有的指标都可以通过无状态的可调用项来表示,因为在培训和评估期间,每个批次的指标值都是经过评估的,但是在某些情况下,每个批次值的平均值并不是您感兴趣的。
假设您想要计算给定评估数据集的AUC:每批AUC值的平均值与整个数据集的AUC值的平均值不同。
对于这样的度量,您将希望Metric类的子类,该类可以跨批维护状态。很简单:
在init中创建状态变量
更新Update_state()中给定y_true和y_pred的变量
在result()中返回度量结果
清除reset_states()中的状态
下面是一个计算二进制真正数的简单示例:
class BinaryTruePositives(tf.keras.metrics.Metric): def __init__(self, name='binary_true_positives', **kwargs): super(BinaryTruePositives, self).__init__(name=name, **kwargs) self.true_positives = self.add_weight(name='tp', initializer='zeros') def update_state(self, y_true, y_pred, sample_weight=None): y_true = tf.cast(y_true, tf.bool) y_pred = tf.cast(y_pred, tf.bool) values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True)) values = tf.cast(values, self.dtype) if sample_weight is not None: sample_weight = tf.cast(sample_weight, self.dtype) values = tf.multiply(values, sample_weight) self.true_positives.assign_add(tf.reduce_sum(values)) def result(self): return self.true_positives def reset_states(self): self.true_positives.assign(0)m = BinaryTruePositives()m.update_state([0, 1, 1, 1], [0, 1, 0, 0])print('Intermediate result:', float(m.result()))m.update_state([1, 1, 1, 1], [0, 1, 1, 0])print('Final result:', float(m.result()))
add_metric() API
在编写自定义层或子类模型的前向传递时,有时可能需要动态记录某些数量,作为指标。在这种情况下,可以使用add_metric()方法。
假设你想要记录一个密集的自定义层的激活的平均值。您可以执行以下操作:
class DenseLike(Layer): """y = w.x + b""" ... def call(self, inputs): output = tf.matmul(inputs, self.w) + self.b self.add_metric(tf.reduce_mean(output), aggregation='mean', name='activation_mean') return output
然后将以“activation_mean”的名称跟踪数量。跟踪的值将是每批指标值的平均值(由aggregation='mean'指定)。
明天我们将具体地学习指标函数,谢谢大家的观看!
Artificial Intelligence
![2e278298a41bc445e2afcbaccd5d1930.png](https://i-blog.csdnimg.cn/blog_migrate/4ce8848d6acf573d4b470f84aca18bf1.jpeg)
更多有趣资讯扫码关注 BBIT