matplotlib均值和方差图-多组成功率为例-代码

16 篇文章 9 订阅
13 篇文章 5 订阅

matplotlib多组均值和方差图-成功率为例

前言:

虽然主流的强化很少有成功率的柱状图表示,但是在机械臂任务中,还是有不少成功率的展示,因此将这个脚本优化精简分享出来了。
供大家参考一下~

多组实际效果:

在这里插入图片描述

代码:

"""
"@Author    : kaixindelele,
"@Contact   : 
CSDN:	https://blog.csdn.net/hehedadaq;
知乎: 	https://www.zhihu.com/people/heda-he-28
"@Describe  : 对于多任务、多设置实验的成功率可视化,
            读取文本的问题有点复杂,我就直接把数据手动输入到字典中。
"""

import matplotlib.pyplot as plt
import numpy as np

DIV_LINE_WIDTH = 50

# Global vars for tracking and labeling data at load time.
exp_idx = 0
units = dict()
barFontSize = 16
xTicksFontSize = 16
yTicksFontSize = 16
yLabelFontSize = 16
legendFontSize = 16
titleFontSize = 16
errorBarSize = 5


def plot_success(exp_data_dict, task_names, exp_names,
                 title_str='Figure', y_label_str="Success Rate (%)"):
    """
    :param exp_data_dict: [name:list(1,2,3)]
    :param task_names: ['Reaching task', 'Lifting task', ]
    :param exp_names: ['PNO', 'PNR', 'POR']
    :param title_str: 'Fig'
    :param y_label_str: 'Success Rate (%)'
    :return: None
    """
    # 按照实验名迭代,对每个任务画一个柱子
    for exp_index, exp_name in enumerate(exp_names):
        total_mean = []
        total_std = []
        for task_name in task_names:
            mean_list = []
            for m in exp_data_dict[exp_name+'_'+task_name]:
                success_num = m
                mean_list.append([np.round(float(success_num)/1.0, 2)])
            mean = np.mean(np.array(mean_list))
            std = np.std(np.array(mean_list))
            total_mean.append(mean)
            total_std.append(std)

        bar_width = 0.5
        # 有i个实验,在不同任务中的位置。
        x = np.arange(len(task_names)) * (len(exp_names)+1) * bar_width + exp_index * bar_width
        print("x:", x)
        rect_mean = plt.bar(x=x,
                            height=total_mean,
                            width=bar_width,
                            align="center",
                            label=exp_name,
                            )
        rect_std = plt.errorbar(x=x,
                                y=total_mean,
                                yerr=total_std,
                                fmt='o',        # 中心点形状
                                ecolor='r',     # 竖线颜色
                                color='b',      # 横线颜色
                                elinewidth=2,   # 线宽
                                capsize=errorBarSize,   # 横线长度
                                )
        # 给legend赋值字体大小
        plt.legend(loc=0, numpoints=1)
        leg = plt.gca().get_legend()
        text = leg.get_texts()
        plt.setp(text, fontsize=legendFontSize)
        # 给每个柱状图都标上均值,按照规律来。
        for i, y in enumerate(total_mean):
            plt.text(x[i], 1*y + 1.5, '%s' % round(y, 2), ha='center', FontSize=barFontSize)
    # 在x的基础上,加了半个宽度
    x_center_list = [bar_width*(i*(len(exp_names)+1)+(len(exp_names))/2) for i in range(len(task_names))]
    print(x_center_list)
    np.arange(len(task_names)) * (len(exp_names) + 1)

    plt.xticks(x_center_list, task_names, FontSize=xTicksFontSize)
    plt.yticks(FontSize=yTicksFontSize)
    plt.ylabel(y_label_str, FontSize=yLabelFontSize)

    plt.title(title_str, FontSize=titleFontSize)
    plt.show()


def main():
    data_dict = dict()
    # 不同的任务,每个任务有不同的实验设置,中间用_连起来
    data_dict = {"PNO_Reaching task": [77.8, 20, 80],
                 "PNR_Reaching task": [90.0, 89, 29],

                 "PNO_Lifting task": [8.2, 30, 90],
                 "PNR_Lifting task": [69.2, 20, 102],
                 }
    task_names = ['Reaching task', 'Lifting task']
    exp_names = ['PNO', 'PNR']
    plot_success(data_dict,
                 task_names=task_names,
                 exp_names=exp_names)


if __name__ == "__main__":
    main()

单组成功率柱状图:

单组效果图

在这里插入图片描述

"""
"@Author    : kaixindelele,
"@Contact   : 
CSDN:	https://blog.csdn.net/hehedadaq;
知乎: 	https://www.zhihu.com/people/heda-he-28
"@Describe  : 对于多任务、多设置实验的成功率可视化,
            读取文本的问题有点复杂,我就直接把数据手动输入到字典中。
"""

import matplotlib.pyplot as plt
import numpy as np

DIV_LINE_WIDTH = 50

# Global vars for tracking and labeling data at load time.
exp_idx = 0
units = dict()
barFontSize = 16
xTicksFontSize = 16
yTicksFontSize = 16
yLabelFontSize = 16
legendFontSize = 16
titleFontSize = 16
errorBarSize = 5


def plot_success(exp_data_dict, 
                #  task_names, exp_names,
                 title_str='Success Rate (%)', y_label_str="Success Rate (%)"):
    """
    :param exp_data_dict: [name:list(1,2,3)]
    :param task_names: ['Reaching task', 'Lifting task', ]
    :param exp_names: ['PNO', 'PNR', 'POR']
    :param title_str: 'Fig'
    :param y_label_str: 'Success Rate (%)'
    :return: None
    """
    bar_width = 0.6
    # 有i个实验,在不同任务中的位置。
    x_axis_tick = np.arange(len(exp_data_dict))
    values = []
    labels = []
    for key, value in exp_data_dict.items():
        values.append(value[0])
        labels.append(key)
    print("x_axis_tick:", x_axis_tick)
    print("values:", values)
    print("labels:", labels)
    for index in range(len(exp_data_dict)):
        rect_mean = plt.bar(x=x_axis_tick[index],
                            height=values[index],
                            width=bar_width,
                            align="center",
                            label=labels[index],
                            )
    
    # 给legend赋值字体大小
    # plt.legend(loc=0, numpoints=1)
    plt.legend(loc='upper center',
               borderaxespad=0.,
               )
    leg = plt.gca().get_legend()
    text = leg.get_texts()
    plt.setp(text, fontsize=legendFontSize)
    # 给每个柱状图都标上均值,按照规律来。
    for i, y in enumerate(values):
        plt.text(x_axis_tick[i], 1*y + 1.5, '%s' % round(y, 2), ha='center', FontSize=barFontSize)
    
    plt.xticks(x_axis_tick, labels, FontSize=xTicksFontSize)
    # 确保y轴标签处于0-100,每隔20一个,
    y_ticks = np.arange(0, 120, 20)
    print("y_ticks:", y_ticks)
    plt.yticks(y_ticks, FontSize=yTicksFontSize)
    plt.ylabel(y_label_str, FontSize=yLabelFontSize)

    # plt.title(title_str, FontSize=titleFontSize)
    plt.show()


def main():
    data_dict = dict()
    # 确保成功率是%格式!不能小于1
    data_dict = {"dense": [10.0/15.0 * 100.0],
                 "dense2sparse": [14.0/15.0 * 100.0],
                 "sparse": [0.0/15.0 * 100.0],
                 }
    plot_success(data_dict,
                 )


if __name__ == "__main__":
    main()

联系方式:

ps: 欢迎做强化的同学加群一起学习:

深度强化学习-DRL:799378128

欢迎关注知乎帐号:未入门的炼丹学徒

CSDN帐号:https://blog.csdn.net/hehedadaq

极简spinup+HER+PER代码实现:https://github.com/kaixindelele/DRLib

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

hehedadaq

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值