tensorflow2.0使用自带的函数求精准率和召回率(解决Shapes (None, 10) and (None, 1) are incompatible)

本代码使用的是cifar10数据集,所以有十个类别
废话不多说,直接给代码吧

import tensorflow as tf
from tensorflow.keras import datasets, Sequential, layers,metrics
(x_train, y_train), _ = datasets.cifar10.load_data()

def procession(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.squeeze(y)
    y = tf.one_hot(y, depth=10)

    return x, y
model = Sequential([
    layers.Flatten(input_shape=(32, 32, 3)),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(32, activation='relu'),
    layers.Dense(10, activation='softmax')
])
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1000).map(procession).batch(128)
# model.compile(loss=tf.losses.binary_crossentropy, optimizer='adam', metrics=['accuracy'])
model.compile(loss=tf.losses.binary_crossentropy, optimizer='adam', metrics=[metrics.Recall()])
# model.compile(loss=tf.losses.binary_crossentropy, optimizer='adam', metrics=[metrics.Precision()])
model.fit(train_db, epochs=5)

出现Shapes (None, 10) and (None, 1) are incompatible的原因是:
x通过模型之后会得到一个shape为(None,10)的数据, 而y因为没有进行one_hot编码,y.shape=(None, 1),形状不同所以不能进行计算
没有对标签y_train进行one_hot编码,但是单单进行one_hot编码也是不够的,因为进行one_hot编码之后y_train.shape = (None, 1, 10)就会报错Shapes (None, 10) and (None, 1,10) are incompatible, 所以在对y_train进行处理时,通过tf.squeeze(y_train)是的y_train.shape = (None, 10),这样子就可以进行计算了。

评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值