记录一次使用tensorboard画图实例
# 导入包
from torch.utils.tensorboard import SummaryWriter
# 实例化内容
writer = SummaryWriter('./output/log')
for i in range(100):
# 分别是 题目,数值,x坐标
writer.add_scalar(tag, scalar_value, global_step=None)
writer.close()
查看
在终端输入
tensorboard --logdir=./output/log
打开浏览器复制链接即可查看
常见问题
1.该报错为输入的scalar格式不对,可能为ndarray格式,需要进行转换
scalar should be 0d
2.图片显示不全取消第二个√
科研绘图
需要有图例,但是目前tensorboard不能添加,所以需要将数据进行下载
保存为csv文件后,获取数据
根据自己数据特点绘图
此处代码参考:https://www.pianshen.com/article/316476506/
import pandas as pd
import matplotlib.pyplot as plt
from io import StringIO
net2 = pd.read_csv('run_2-tag-loss.csv', usecols=['Step', 'Value'])
plt.plot(net2.Step, net2.Value, lw=1.5, label='Net-2', color='pink')
net3 = pd.read_csv('run_3-tag-loss.csv', usecols=['Step', 'Value'])
plt.plot(net3 .Step, net3 .Value, lw=1.5, label='Net-3', color='green')
net4 = pd.read_csv('run_4-tag-loss.csv', usecols=['Step', 'Value'])
plt.plot(net4 .Step, net4 .Value, lw=1.5, label='Net-4', color='yellow')
plt.legend(loc=0)
plt.show()
就可以放到论文里面了
具有标准差的论文绘图
# 导入库函数
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('ggplot')
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 平滑处理,类似tensorboard的smoothing函数。
# 此处的x y需要根据自己文件列名修改,权重可以自己设定
def smooth(read_path, save_path, file_name, x='timestep', y='reward', weight=0.75):
data = pd.read_csv(read_path + file_name)
scalar = data[y].values
last = scalar[0]
smoothed = []
for point in scalar:
smoothed_val = last * weight + (1 - weight) * point
smoothed.append(smoothed_val)
last = smoothed_val
save = pd.DataFrame({x: data[x].values, y: smoothed})
save.to_csv(save_path + 'smooth_'+ file_name)
# 平滑预处理原始reward数据
smooth(read_path='./BipedalWalker-v3/', save_path='./BipedalWalker-v3/', file_name='PPO_BipedalWalker-v3_log_210.csv')
smooth(read_path='./BipedalWalker-v3/', save_path='./BipedalWalker-v3/', file_name='PPO_BipedalWalker-v3_log_310.csv')
smooth(read_path='./BipedalWalker-v3/', save_path='./BipedalWalker-v3/', file_name='PPO_BipedalWalker-v3_log_410.csv')
# 读取平滑后的数据与原始数据
df1 = pd.read_csv('./BipedalWalker-v3/smooth_PPO_BipedalWalker-v3_log_210.csv') #[1100: 1200]
df2=pd.read_csv('./ori/train_data.csv')
# 拼接到一起
df = df1.append(df2)
# 重新排列索引
df.index = range(len(df))
print(df)
# 设置图片大小
plt.figure(figsize=(15, 10))
# 画图,同理xy需要改成自己文件的
sns.lineplot(data=df, x="timestep", y="reward")