读取TensorBoard生成的events文件并绘图

文章讲述了如何使用TensorBoard的EventAccumulator处理留一人交叉验证实验中的事件文件,提取每个个体的最高、最低及平均准确率,并通过可视化图表展示结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

写在前面

因为题主用的留一人交叉验证方法,故在一个tensorboard文件下会有多个人的结果文件,想要快速获取各个人最高的准确率以及平均准确率,用代码读取events文件明显是最快的
参考:https://blog.csdn.net/zywvvd/article/details/88865416

具体实现

  1. 先看看events文件里有什么key,然后看看其值
from tensorboard.backend.event_processing import event_accumulator
path = ""
ea = event_accumulator.EventAccumulator(path)
ea.Reload()
print(ea.scalars.Keys())
val_acc = ea.scalars.Items("spatial/val/accuracy_epoch")
train_acc = ea.scalars.Items("spatial/train/accuracy_epoch")
  1. 这边为留一人交叉验证的方法,为了得到每个人最高的准确率
import os
from tensorboard.backend.event_processing import event_accumulator
import numpy as np
import matplotlib.pyplot as plt

def get_perno_acc(path):
    ea = event_accumulator.EventAccumulator(path)
    ea.Reload()
    val_acc = ea.scalars.Items("spatial/val/accuracy_epoch")
    max_val_acc = max([i.value for i in val_acc])
    return max_val_acc


def get_ex_acc(dir_path):
    for i in dir_path.split("/")[-1].split("_"):
        if "seed" in i:
            break
        print(i, end=" ")
    print()
    perno_acc = dict()
    perno = os.listdir(dir_path)
    perno.sort(key=lambda x: int(x.split("_")[-1]))
    for per_no in perno:
        per_path = os.path.join(dir_path, per_no)
        event_path = os.path.join(
            per_path, [i for i in os.listdir(per_path) if "events" in i][0]
        )
        # print(event_path)
        perno_acc[int(per_no[6:])] = get_perno_acc(event_path)
    print(perno_acc)
    print(f"ACC Max:{np.max(list(perno_acc.values()))}")
    print(f"ACC Min:{np.min(list(perno_acc.values()))}")
    print(f"ACC Mean:{np.mean(list(perno_acc.values()))}")
    return perno_acc
  1. 然后是绘成柱状图方便观看
# task_spatial_de_4
subject_lists = [ i for i in task_spatial_de_4.keys()]
print(subject_lists)
acc = [task_spatial_de_4[i] for i in subject_lists]
plt.figure(figsize=(12, 5))
plt.bar(subject_lists, acc)
plt.grid()
plt.xticks(subject_lists)
plt.yticks(np.arange(0, 1.1, 0.1))
# 显示每个数值
for x, y in zip(subject_lists, acc):
    plt.text(x, y + 0.01, f"{y:.2f}", ha="center", va="bottom", fontsize=10)
plt.xlabel("Participant")
plt.ylabel("Accuracy")
plt.title("Task Spatial DE 4 Participant Accuracy")
# 调整间距
plt.tight_layout()
plt.show()

具体图什么的因为比较私密就不放了

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值