import matplotlib.pyplot as plt
import json
from collections import defaultdict
# Read a JSON file, group records by epoch, and compute the average loss for each epoch.
def process_json_file(file_path):
    """Compute the mean loss per epoch from a line-delimited JSON log.

    Every non-blank line of the file is parsed as one JSON object. Records
    that provide both an ``"epoch"`` and a ``"loss"`` value are grouped by
    epoch; records where either key is absent or ``null`` are ignored.

    Args:
        file_path: Path of the log file to read, one JSON object per line.

    Returns:
        A tuple ``(epochs, average_losses)``: the epoch values in ascending
        order and, at matching positions, the arithmetic mean of the losses
        recorded for that epoch. Both lists are empty when no line carries
        both keys.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    epoch_losses = defaultdict(list)
    # Decode as UTF-8 explicitly instead of depending on the locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line is not a JSON document; json.loads would raise.
                continue
            data = json.loads(line)
            epoch = data.get("epoch")
            loss = data.get("loss")
            if epoch is not None and loss is not None:
                epoch_losses[epoch].append(loss)
    # Average the collected losses, reporting epochs in ascending order.
    epochs = sorted(epoch_losses)
    average_losses = [
        sum(epoch_losses[epoch]) / len(epoch_losses[epoch]) for epoch in epochs
    ]
    return epochs, average_losses
# Location of the scalars log to plot; change this to match your own run directory.
json_file_path = "/home/znck/PycharmProjects/mmselfsup-main/mmselfsup-main/checkpoints/20230920_231524/vis_data/scalars.json"

# Parse the log into sorted epochs and their mean losses.
epochs, average_losses = process_json_file(json_file_path)

# Draw the loss curve on the axes of a new 8x6 inch figure.
plt.figure(figsize=(8, 6))
ax = plt.gca()
ax.plot(epochs, average_losses, linestyle='-')
ax.set_xlabel('Epoch')
ax.set_ylabel('Average Loss')
ax.set_title('Average Loss Over Epochs')
ax.grid(True)
plt.show()