python variable shape 不匹配,为什么我会遇到Keras形状不匹配的情况?

I am following a Keras mnist example for beginners. I have tried to change the labels to suit my own data which has 3 distinct text classifications. I am using "to_categorical" to achieve this. The shape looks right to me, but "fit" gets an error:

train_labels = keras.utils.to_categorical(train_labels, num_classes=3)

print(train_images.shape)

print(train_labels.shape)

model = keras.Sequential([

keras.layers.Flatten(input_shape=(28, 28)),

keras.layers.Dense(128, activation=tf.nn.relu),

keras.layers.Dense(3, activation=tf.nn.softmax)

])

model.compile(optimizer='adam',

loss='sparse_categorical_crossentropy',

metrics=['accuracy'])

model.fit(train_images, train_labels, epochs=5)

(7074, 28, 28)

(7074, 3)

Blockquote

Blockquote

Traceback (most recent call last): File

"C:/Users/lawrence/PycharmProjects/tester2019/KeraTest.py", line 131,

in

model.fit(train_images, train_labels, epochs=5) File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training.py",

line 1536, in fit

validation_split=validation_split) File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training.py",

line 992, in _standardize_user_data

class_weight, batch_size) File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training.py",

line 1154, in _standardize_weights

exception_prefix='target') File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training_utils.py",

line 332, in standardize_input_data

' but got array with shape ' + str(data_shape)) ValueError: Error when checking target: expected dense_1 to have shape (1,) but got

array with shape (3,)

解决方案

You need to use categorical_crossentropy instead of sparse_categorical_crossentropy as loss since your labels are one hot encoded.

Alternatively, you can use sparse_categorical_crossentropy if you do not one hot encode the labels. In that case, the labels should have shape (batch_size, 1).

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值