TensorFlow 2.5: model.evaluate fails with "Invalid argument: required broadcastable shapes at loc(unknown)"

When evaluating a model with model.evaluate in TensorFlow 2.5, the call fails with "Invalid argument: required broadcastable shapes at loc(unknown)".

1. Software environment ⚙️

Windows 10 Education, 64-bit
Python 3.6.3
Tensorflow-GPU 2.5.0
CUDA 11.1

2. Problem description 🔍

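Once a model has been trained, its performance needs to be evaluated. In tensorflow.keras, a common approach is to read local images with the .flow_from_directory function and then measure the model's accuracy with model.evaluate: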

test_datagen = ImageDataGenerator(preprocessing_function=preprocessing_function)

test_generator = test_datagen.flow_from_directory(test_dir,
                                                  shuffle=False,
                                                  target_size=(299,299),
                                                  batch_size=32)

print("================开始模型评估======================")  # "Start model evaluation"
model_evaluation = model.evaluate(test_generator, verbose=1)

For example, suppose we have an evaluation dataset named val-fewer-sample that contains two classes, dog and cat.
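These samples are already labeled correctly (the label is the folder name). The trained classifier predicts each sample, and an image counts as correctly predicted when the prediction matches its label.
If you trained with softmax, there is no problem. For binary classification, though, I prefer sigmoid, and in that case TensorFlow 2.5 raises the following error: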

Invalid argument: required broadcastable shapes at loc(unknown)

2022-09-01 10:31:58.676170: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-09-01 10:31:58.676482: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-09-01 10:31:58.676653: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "C:\Program Files\JetBrains\PyCharm 2020.1\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm 2020.1\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "E:/Code/Python/classification model evaluation/model_evaluation_sigmoid.py", line 71, in <module>
    model_evaluation = model.evaluate(test_generator, verbose=1, workers=4, return_dict=True)
  File "C:\Users\Anaconda3\envs\tf2.5\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1489, in evaluate
    tmp_logs = self.test_function(iterator)
  File "C:\Users\Anaconda3\envs\tf2.5\lib\site-packages\tensorflow\python\eager\def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\Anaconda3\envs\tf2.5\lib\site-packages\tensorflow\python\eager\def_function.py", line 957, in _call
    filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
  File "C:\Users\Anaconda3\envs\tf2.5\lib\site-packages\tensorflow\python\eager\function.py", line 1961, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "C:\Users\Anaconda3\envs\tf2.5\lib\site-packages\tensorflow\python\eager\function.py", line 596, in call
    ctx=ctx)
  File "C:\Users\Anaconda3\envs\tf2.5\lib\site-packages\tensorflow\python\eager\execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  required broadcastable shapes at loc(unknown)
	 [[node LogicalAnd (defined at E:/Code/Python/classification model evaluation/model_evaluation_sigmoid.py:71) ]]
	 [[assert_greater_equal_1/Assert/AssertGuard/else/_29/assert_greater_equal_1/Assert/AssertGuard/Assert/data_0/_65]]
  (1) Invalid argument:  required broadcastable shapes at loc(unknown)
	 [[node LogicalAnd (defined at E:/Code/Python/classification model evaluation/model_evaluation_sigmoid.py:71) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_test_function_5740]
Function call stack:
test_function -> test_function

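As the traceback shows, TensorFlow complains that the tensor shapes do not match. This is puzzling: why can a model trained with softmax be evaluated normally, while a model trained with sigmoid raises an error?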

3. Solution 🐡

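After some searching, the cause turned out to be that the class_mode parameter of .flow_from_directory defaults to categorical, which yields one-hot labels with one column per class, whereas our model was trained with sigmoid and presumably emits a single value per image: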

  def flow_from_directory(self,
                          directory,
                          target_size=(256, 256),
                          color_mode='rgb',
                          classes=None,
                          class_mode='categorical',
                          batch_size=32,
                          shuffle=True,
                          seed=None,
                          save_to_dir=None,
                          save_prefix='',
                          save_format='png',
                          follow_links=False,
                          subset=None,
                          interpolation='nearest'):
    """Takes the path to a directory & generates batches of augmented data."""
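To make the mismatch concrete, here is a minimal plain-Python sketch of the two label layouts. It does not call TensorFlow; `encode_labels` and `count_values` are hypothetical helpers written for illustration, the batch and the prediction values are made up, and the single-unit sigmoid output is an assumption about the model.

```python
def encode_labels(class_indices, num_classes, class_mode):
    """Mimic the label layout of two class_mode settings in plain Python."""
    if class_mode == "categorical":
        # One-hot: one row of length num_classes per sample.
        return [[1.0 if i == c else 0.0 for i in range(num_classes)]
                for c in class_indices]
    if class_mode == "binary":
        # One scalar (0.0 or 1.0) per sample.
        return [float(c) for c in class_indices]
    raise ValueError(f"unsupported class_mode: {class_mode!r}")


def count_values(labels):
    """Total number of scalar values in a flat or nested label list."""
    return sum(len(row) if isinstance(row, list) else 1 for row in labels)


# Hypothetical batch of four images: cat=0, dog=1 (folder names sorted alphabetically).
batch = [0, 1, 1, 0]
# Made-up sigmoid predictions, one value per image (assumed single output unit).
sigmoid_outputs = [[0.1], [0.8], [0.7], [0.3]]

categorical = encode_labels(batch, 2, "categorical")
binary = encode_labels(batch, 2, "binary")

print(count_values(sigmoid_outputs))  # number of predictions
print(count_values(categorical))      # twice as many label values as predictions
print(count_values(binary))           # same count as the predictions
```

With one-hot labels there are two label values per image but only one prediction, so element-wise comparisons inside the metrics have nothing to line up against. With binary labels the counts agree.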

Therefore, class_mode must be changed to binary so that the labels match a sigmoid output, i.e. change:

test_generator = test_datagen.flow_from_directory(test_dir,
                                                  shuffle=False,
                                                  target_size=(299,299),
                                                  batch_size=32)
# Change to:
test_generator = test_datagen.flow_from_directory(test_dir,
                                                  shuffle=False,
                                                  target_size=(299,299),
                                                  class_mode='binary',
                                                  batch_size=32)
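To avoid hard-coding the setting, the choice could be derived from the size of the model's last layer. The sketch below is only an illustration: `pick_class_mode` is a hypothetical helper, and it assumes that one output unit means a sigmoid binary classifier and that several units mean a softmax classifier, which will not hold for every model (multi-label sigmoid heads, for instance).

```python
def pick_class_mode(output_units):
    """Hypothetical helper: map the last layer's unit count to a class_mode.

    Assumption: 1 unit -> sigmoid binary classifier, >1 units -> softmax.
    """
    if output_units == 1:
        return "binary"
    if output_units > 1:
        return "categorical"
    raise ValueError(f"output_units must be positive, got {output_units}")


print(pick_class_mode(1))  # single sigmoid unit
print(pick_class_mode(2))  # two-way softmax
```

In the evaluation script this could be used as class_mode=pick_class_mode(model.output_shape[-1]) in the flow_from_directory call, so the generator follows the model instead of the default.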

4. Result preview 🤔

After changing class_mode, model evaluation runs normally:

Found 12226 images belonging to 2 classes.
================开始模型评估======================
2022-09-01 13:47:56.720307: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2022-09-01 13:47:59.106484: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cudnn64_8.dll
2022-09-01 13:48:00.179758: I tensorflow/stream_executor/cuda/cuda_dnn.cc:359] Loaded cuDNN version 8201
C:\Users\Anaconda3\envs\tf2.5\lib\site-packages\PIL\Image.py:976: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
2022-09-01 13:48:02.140887: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cublas64_11.dll
2022-09-01 13:48:02.827805: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cublasLt64_11.dll
 42/192 [=====>........................] - ETA: 32s - loss: 0.5887 - accuracy: 0.9371 - precision: 0.8920 - recall: 0.9420
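The accuracy, precision and recall values in the progress bar can be cross-checked by hand on a small sample. The following is a rough sketch in plain Python, not the Keras implementation: `binary_metrics` is a hypothetical helper, the labels and probabilities are made up, and the 0.5 threshold is assumed to match the metric configuration used at compile time.

```python
def binary_metrics(y_true, y_prob, threshold=0.5):
    """Compute accuracy, precision and recall from 0/1 labels and sigmoid outputs."""
    y_pred = [1 if p > threshold else 0 for p in y_prob]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # true positives
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)  # true negatives
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # false positives
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # false negatives
    return {
        "accuracy": (tp + tn) / len(pairs),
        "precision": tp / (tp + fp) if (tp + fp) else 0.0,
        "recall": tp / (tp + fn) if (tp + fn) else 0.0,
    }


# Made-up example: binary labels (as produced with class_mode='binary')
# and sigmoid probabilities for five images.
labels = [0, 1, 1, 0, 1]
probabilities = [0.2, 0.9, 0.4, 0.6, 0.8]

result = binary_metrics(labels, probabilities)
print(result)
```

Note that both inputs are flat lists with one value per image, which is exactly the layout that class_mode='binary' and a sigmoid output agree on.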

