import argparse
import matplotlib.pyplot as plt
import pandas as pd
def main(path='train_records.csv'):
    """Plot training and validation loss curves from a CSV training log.

    Reads the CSV at ``path``, takes its ``epoch``, ``train_loss`` and
    ``val_loss`` columns, draws both loss curves against the epoch on a
    single axes and displays the figure with ``plt.show()``.

    Args:
        path: Location of the ``.csv`` log file. Defaults to
            ``'train_records.csv'``, the file name that was previously
            hard-coded, so calling ``main()`` behaves as before.

    Raises:
        KeyError: If one of the required columns is missing from the log.
    """
    # NOTE: the path could alternatively be supplied from the terminal, e.g.
    # via an argparse '-path' option whose value is passed to this function.
    df = pd.read_csv(path)

    # The header labels in the log may carry stray whitespace (the columns
    # were originally looked up as 'train_loss\t' and 'val_loss\t'), so strip
    # every label once and use the clean names from here on. This accepts
    # both the tab-suffixed header and a clean one.
    df = df.rename(columns=str.strip)

    required = ('epoch', 'train_loss', 'val_loss')
    missing = [name for name in required if name not in df.columns]
    if missing:
        raise KeyError(f"{path!r} is missing required column(s): {missing}")

    epoch = df['epoch']
    train_loss = df['train_loss']
    val_loss = df['val_loss']

    # NOTE: a three-panel variant (loss / PSNR / SSIM on a 1x3 grid with
    # figsize=(15, 3.5)) was previously kept here as commented-out code; it
    # additionally read 'psnr' and 'ssim' columns from the same CSV.

    # Single plot; the figure size is left to matplotlib's defaults.
    fig, ax = plt.subplots(1, 1)
    ax.plot(epoch, train_loss, label='train_loss')
    ax.plot(epoch, val_loss, label='val_loss')
    ax.set(xlabel='Epoch', ylabel='Loss',
           title='Train (Val) Loss Analysis')
    ax.legend()

    plt.tight_layout()
    plt.show()
# Run the plot only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
# Plotting train and val loss with Python
# (Source-page note: "latest recommended article", published 2024-06-29 16:53:05)