在建立模型的损失函数时,直接使用的tensorflow keras自带的MSE函数,传入的是3D张量,但是在训练的过程中,报错ValueError: operands could not be broadcast together with shapes。
查了形状方面不匹配,但是我把模型结构图片展示出来,并没有发现形状上有什么不对。考虑到是fit函数训练时出错,新加的代码只有损失那边,由于我的数据是时间序列数据所以输入时3D张量(samples, time_step, features),模型的输出也是3D张量,直接调用loss = MSE(inputs, outputs), Mode.add_loss(loss),就直接报错了。
所以直接测试了关于MSE函数,分别使用3D张量、2D张量、向量来进行测试,看MSE输入的结果是什么形状。 首先看一下MSE与MAE的定义。
1. MSE(Mean Squared Error)均方误差
2. RMSE(Root Mean Squared Error)
3. MAE(Mean Absolute Error)平均绝对误差
import numpy as np
import tensorflow as tf
import keras as K
from tensorflow.keras.losses import mse
tf.compat.v1.enable_eager_execution()
#3D张量
a = tf.constant([[[1],[2],[3]]], dtype='float32')
b = tf.constant([[[7],[4],[9]]], dtype='float32')
c = mse(a, b)
print(c)
#向量
d = tf.reshape(a, [-1])
e = tf.reshape(b, [-1])
f = mse(d, e)
print(f)
#2D张量
g = tf.reshape(a, [-1, 1])
h = tf.reshape(b, [-1, 1])
o = mse(g, h)
print(o)
#输出
tf.Tensor([[36. 4. 36.]], shape=(1, 3), dtype=float32)
tf.Tensor(25.333334, shape=(), dtype=float32)
tf.Tensor([36. 4. 36.], shape=(3,), dtype=float32)
#自己的例子
nu_inputs = tf.reshape(inputs, [-1,1])
nu_outputs = tf.reshape(outputs, [-1,1])
print(nu_inputs.shape)
print(nu_outputs.shape)
loss = mse(nu_inputs, nu_outputs)
loss = tf.reshape(loss, [-1])
loss = tf.reduce_sum(loss)
print(loss.shape)
#加入损失函数
lstm_endecoder.add_loss(loss)
根据代码给出的结果,我们可以看到三种输出都是张量形式,其中3D张量输出为2D张量,向量输出为标量,2D张量输出为向量,都向下降了一个维度。
其中输出中3D张量元素和向量元素的和除以元素的个数3,就是标量的结果。(36+4+36)/ 3 = 25.333334,最终我们要得到标量的结果。如果我们使用3D张量和2D张量的形式,最后需要
进行处理,得到一个元素和,使之和标量的结果相同。
最后的代码是自己的一个例子,inputs和outputs是网络的输入和输出,首先我把它们转化为2D张量,计算MSE之后,直接作为loss传入网络,结果还是形状不匹配报错。
所以代码中,我在使用2D张量计算了MSE误差之后,将loss形状改为向量,然后求和,传入网络,结果就没有错误。
这个错误一定要注意,目前我认为是无论传入MSE函数是什么形状,一定要保证计算完MSE之后,将结果变为一个标量。也有可能是其他的问题,有更好的解释,欢迎大家指正,共同学习。
参考资料
- 评估指标——均方误差(MSE)、平均绝对误差(MAE)https://blog.csdn.net/Hachi_Lin/article/details/93884333?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.edu_weight&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.edu_weight
-
均方误差(MSE)根均方误差(RMSE)平均绝对误差(MAE)https://blog.csdn.net/xiongchengluo1129/article/details/79155550
-
图像质量评估指标 | MAE | MSE | PSNR | SSIM https://blog.csdn.net/stone_fall/article/details/89389269