- 概述
前面几节讲的是linear regression的内容,这里咱们再讲一个非常常用的一种模型那就是classification,classification顾名思义就是分类的意思,在实际的情况是非常常用的,例如咱们可以定义房价是否过高,如果房价高于100万,则房价过高,设置成true;如果房价低于100万,则房价不高,target就可以设置成false。这里的target就只有2种,分别只有True和False,而不像咱们的的linear regression那样target是连续的。在实际的应用中,这是有非常广泛的应用的,这一节的第一部分主要是讲如何用TensorFlow来训练一个classifier模型来预测classification problems。第二部分主要解释一下measure classification模型的的方法,那就是ROC curve。在linear regression中咱们知道有MAE,MSE等等一些列的方式来判断咱们的模型的表现怎么样,那么在classification中,MAE和MSE都不适用的,那么咱们用什么measurement来判断咱们的模型好不好呢?这时候就需要介绍咱们的ROC curve了。
- TensorFlow应用之Classification
如果咱们的target只有2个(True/False 或者 1/0等等),这种情况咱们一般称之为binary classification problem;如果咱们的target的数量大于2,咱们一般称之为multi_class classification problem。这两种方式无论是哪一种,在咱们用TensorFlow训练的时候,它的的API都是一样的,只是multi-class需要在定义模型的的时候设置一个n_classes参数而已,其他都一样。另外的建模过程跟前面章节说的一样,这一节主要介绍一下他在TensorFlow的应用中跟linear regression的区别,所以我就不会展示整个建模的过程,只会展示他们的不同。第一个不同就是模型定义的时候不同,那么现在来看一下吧
linear_classifier = tf.estimator.LinearClassifier(feature_columns =