Metrics in TensorFlow: how to use tf.keras.metrics in multiclass classification?

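A shape-incompatibility error came up while doing multiclass classification in TensorFlow. The cause was using metrics meant for binary classification rather than multiclass classification. To fix it, a custom CategoricalTruePositives metric was defined and the categorical_crossentropy loss was used for the multiclass problem. After training, the model reports the number of true positives for each class.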
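The summary does not show the CategoricalTruePositives code itself. The following is a minimal sketch of what such a metric could look like, assuming one-hot labels and softmax outputs and counting a true positive for class k whenever the arg-max prediction is k and the label is also k. The class body, its `num_classes` argument and the default name are assumptions for illustration, not the original author's implementation.

```python
import tensorflow as tf


class CategoricalTruePositives(tf.keras.metrics.Metric):
    """Per-class true-positive counts (a hypothetical sketch, not the original CSDN code)."""

    def __init__(self, num_classes, name="categorical_tp", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        # One running counter per class, shape (num_classes,).
        self.tp = self.add_weight(name="tp", shape=(num_classes,), initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # y_true: one-hot labels (batch, num_classes); y_pred: softmax scores.
        # sample_weight is ignored in this sketch.
        true_cls = tf.argmax(y_true, axis=-1)
        pred_cls = tf.argmax(y_pred, axis=-1)
        # 1.0 where the arg-max prediction matches the label, else 0.0.
        correct = tf.cast(tf.equal(true_cls, pred_cls), tf.float32)
        # One-hot of the predicted class, kept only for correct predictions.
        hits = tf.one_hot(pred_cls, self.num_classes) * correct[:, None]
        self.tp.assign_add(tf.reduce_sum(hits, axis=0))

    def result(self):
        return tf.convert_to_tensor(self.tp)

    def reset_state(self):
        # Zero every per-class counter between epochs.
        self.tp.assign(tf.zeros((self.num_classes,)))
```

Under these assumptions it would be passed like any built-in metric, e.g. metrics=[CategoricalTruePositives(num_classes=4)], together with loss='categorical_crossentropy' and one-hot labels.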

I want to use some of these metrics when training my neural network:

```python
import tensorflow as tf
from tensorflow import keras

METRICS = [
    keras.metrics.TruePositives(name='tp'),
    keras.metrics.FalsePositives(name='fp'),
    keras.metrics.TrueNegatives(name='tn'),
    keras.metrics.FalseNegatives(name='fn'),
    keras.metrics.Precision(name='precision'),
    keras.metrics.Recall(name='recall'),
    keras.metrics.CategoricalAccuracy(name='acc'),
    keras.metrics.AUC(name='auc'),
]

BATCH_SIZE = 1024
SHUFFLE_BUFFER_SIZE = 4000

# sent_vectors, labels and embed_dim are defined earlier (not shown)
train_dataset = tf.data.Dataset.from_tensor_slices((sent_vectors, labels))
train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)

model = tf.keras.Sequential()
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(embed_dim)))
for units in [256, 256]:
    model.add(tf.keras.layers.Dense(units, activation='relu'))
# Output layer: one softmax probability per class (4 classes)
model.add(tf.keras.layers.Dense(4, activation='softmax'))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=METRICS)

model.fit(
    train_dataset,
    epochs=100)
```

But I get the error "Shapes (None, 4) and (None, 1) are incompatible". I believe this is because I am doing multiclass classification with 4 classes, while the metrics are computed for binary classification. How do I adjust my code for multiclass classification?
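A minimal sketch of why the shapes clash and how one-hot labels resolve it, using made-up toy tensors rather than the asker's data. TruePositives, FalseNegatives, Precision, Recall and AUC require y_true and y_pred to have the same shape, so sparse labels of shape (batch, 1) cannot be scored against softmax outputs of shape (batch, 4). After one-hot encoding, both are (batch, 4); each of the batch × 4 entries is then treated as a separate binary decision at the default 0.5 threshold, so the counts are pooled over all classes.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 4

# Sparse integer labels, shape (batch, 1), as in the question (toy values).
sparse_labels = np.array([[0], [3], [1], [2]])
# Softmax-style predictions, shape (batch, 4) (toy values).
y_pred = tf.constant([[0.7, 0.1, 0.1, 0.1],
                      [0.1, 0.1, 0.1, 0.7],
                      [0.1, 0.6, 0.2, 0.1],
                      [0.3, 0.3, 0.2, 0.2]])

# Drop the trailing axis and one-hot encode: y_true now has shape (batch, 4).
y_true = tf.one_hot(sparse_labels[:, 0], depth=NUM_CLASSES)

tp = tf.keras.metrics.TruePositives()
fn = tf.keras.metrics.FalseNegatives()
recall = tf.keras.metrics.Recall()
for m in (tp, fn, recall):
    # Every entry is thresholded at 0.5 and compared with the matching one-hot entry.
    m.update_state(y_true, y_pred)

print(tp.result().numpy(), fn.result().numpy(), recall.result().numpy())

# In the model, the loss that matches one-hot targets is categorical_crossentropy:
# model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=METRICS)
```

Here the last row has no score above 0.5, so it counts as one false negative and recall is 3/4. Keeping sparse labels would still work for SparseCategoricalAccuracy, but not for the confusion-matrix metrics above.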

Update: I am interested in collecting the metrics during training, as in the TensorFlow Imbalanced Classification tutorial, not just at the end of the fitting process.
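Metrics passed to compile are updated batch by batch and logged once per epoch; model.fit returns a History object whose history dict holds one value per epoch for each metric. The sketch below uses small random stand-in data (shapes shrunk from the question's, labels one-hot) and also tracks precision and recall per class through the class_id argument of Precision and Recall. The data, layer sizes and metric names are assumptions for illustration only.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 4
# Hypothetical stand-in data: (n, 65, 300) from the question shrunk to (64, 5, 8).
x = np.random.rand(64, 5, 8).astype("float32")
y = tf.keras.utils.to_categorical(
    np.random.randint(0, NUM_CLASSES, size=(64, 1)), NUM_CLASSES)

metrics = [tf.keras.metrics.CategoricalAccuracy(name="acc")]
# One Precision/Recall pair per class, restricted to that class with class_id.
for k in range(NUM_CLASSES):
    metrics.append(tf.keras.metrics.Precision(class_id=k, name=f"precision_{k}"))
    metrics.append(tf.keras.metrics.Recall(class_id=k, name=f"recall_{k}"))

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(5, 8)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(8)),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=metrics)

history = model.fit(x, y, batch_size=16, epochs=3, verbose=0)
# history.history maps each metric name to a list with one value per epoch.
for name, values in history.history.items():
    print(name, [round(float(v), 3) for v in values])
```

Passing validation_data to fit would add matching entries prefixed with val_.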

Additional info:

My input data are NumPy arrays: sent_vectors.shape = (number_examples, 65, 300) and labels.shape = (number_examples, 1). There are 4 classes, labeled 0 to 3.
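With labels of shape (number_examples, 1), one way to give the existing tf.data pipeline one-hot targets of shape (4,) per example is to map over the dataset before batching. This is a sketch with zero-filled and random stand-in arrays; the array contents, the to_one_hot helper name and the small batch size are hypothetical.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 4
# Hypothetical stand-in arrays with the question's shapes.
sent_vectors = np.zeros((10, 65, 300), dtype="float32")
labels = np.random.randint(0, NUM_CLASSES, size=(10, 1))


def to_one_hot(x, y):
    # y has shape (1,) per example: drop that axis, then one-hot encode to shape (4,).
    return x, tf.one_hot(tf.squeeze(y, axis=-1), depth=NUM_CLASSES)


train_dataset = (
    tf.data.Dataset.from_tensor_slices((sent_vectors, labels))
    .map(to_one_hot)
    .shuffle(4000)
    .batch(4)
)
print(train_dataset.element_spec)
```

Converting the array up front with tf.keras.utils.to_categorical(labels, 4) before from_tensor_slices gives the same (number_examples, 4) targets.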

Stacktrace:

```
ValueError                                Traceback (most recent call last)
in
      1 model.fit(
      2 train_dataset,
----> 3 epochs=10)

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    726 max_queue_size=max_queue_size,
    727 workers=workers,
--> 728 use_multiprocessing=use_multiprocessing)
    729
    730 def evaluate(self,

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
    322 mode=ModeKeys.TRAIN,
    323 training_context=training_context,
--> 324 total_epochs=epochs)
    325 cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
    326

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in run_one_epoch(model, iterator, execution_function, dataset_size, batch_size, strategy, steps_per_epoch, num_samples, mode, training_context, total_epochs)
    121 step=step, mode=mode, size=current_batch_size) as batch_logs:
    122 try:
--> 123 batch_outs = execution_function(iterator)
    124 except (StopIteration, errors.OutOfRangeError):
    125 # TODO(kaftan): File bug about tf function and errors.OutOfRangeError?

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in execution_function(input_fn)
     84 # `numpy` translates Tensors to values in Eager mode.
     85 return nest.map_structure(_non_none_constant_value,
---> 86 distributed_function(input_fn))
     87
     88 return execution_function

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    455
    456 tracing_count = self._get_tracing_count()
--> 457 result = self._call(*args, **kwds)
    458 if
```
