Seaborn Study Notes


Seaborn is built on top of matplotlib and provides many ready-made plotting templates. These notes are based on https://www.bilibili.com/video/BV1HF411B72n.

Overall layout and style settings

import seaborn as sns
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

matplotlib's default plotting style:

def sinplot(flip=1):
    x = np.linspace(0, 14, 100)
    for i in range(1, 7):
        plt.plot(x, np.sin(x + i*5)*(7-i)*flip)
sinplot()



seaborn's default style:

sns.set()
sinplot()



Seaborn has five built-in themes: darkgrid, whitegrid, dark, white, and ticks.
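To compare the five themes side by side, a sketch like the following can help. It uses sns.axes_style() as a context manager so each subplot picks up its own style when it is created. `compare_styles` and `STYLES` are illustrative names, not part of seaborn, and the sketch assumes a notebook or non-interactive session (no plt.show() call).

```python
import numpy as np
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import seaborn as sns

# Hypothetical helper names, chosen for this example only.
STYLES = ["darkgrid", "whitegrid", "dark", "white", "ticks"]


def compare_styles(styles=STYLES):
    """Draw the same sine curves once per style, one subplot per style."""
    x = np.linspace(0, 14, 100)
    fig = plt.figure(figsize=(15, 3))
    for i, style in enumerate(styles, start=1):
        # axes_style() only applies inside the with-block; the Axes reads
        # the style (background, grid, spines) when it is created here.
        with sns.axes_style(style):
            ax = fig.add_subplot(1, len(styles), i)
            for k in range(1, 7):
                ax.plot(x, np.sin(x + k * 5) * (7 - k))
            ax.set_title(style)
    return fig


fig = compare_styles()
for ax in fig.axes:
    # Print each panel's background color to show the styles differ.
    print(ax.get_title(), mcolors.to_hex(ax.get_facecolor()))
```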

sns.set_style("whitegrid") # whitegrid style
data = np.random.normal(size=(20,6)) + np.arange(6) / 2
sns.boxplot(data=data)
<AxesSubplot:>



sns.set_style("dark") # dark style
sinplot()



sns.set_style("white") # white style
sinplot()



sns.set_style("ticks") # ticks style
sinplot()



Fine-tuning style details

Keep only the x and y axis lines (remove the top and right spines):

sinplot()
sns.despine()



Move the axis lines away from the plot (offset the spines):

sns.violinplot(data=data)
sns.despine(offset=10) # move the spines 10 points away from the data


Hide a specific spine:

sinplot()
sns.despine(left=True) # also hide the left spine (the y axis line)
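The effect of despine() can also be checked directly on the Axes spines. This sketch uses its own toy data (an assumption, not from the original) and combines the two options shown above, left=True and offset=10, on a single Axes.

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("ticks")
fig, ax = plt.subplots()
ax.plot(np.arange(10), np.arange(10) ** 2)   # toy data for illustration

# despine() removes the top and right spines by default; left=True also
# removes the left one, and offset=10 pushes the remaining spines
# 10 points outward, away from the data.
sns.despine(ax=ax, left=True, offset=10)

visible = {side: ax.spines[side].get_visible()
           for side in ["left", "right", "top", "bottom"]}
print(visible)
print(ax.spines["bottom"].get_position())
```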




Use a different style for each subplot:

with sns.axes_style("darkgrid"):
    plt.subplot(211)
    sinplot()
plt.subplot(212)
sinplot(-1)




Plot context settings (scale of plot elements):

sns.set()
sns.set_context("paper")
plt.figure(figsize=(8,6))
sinplot()



sns.set_context("talk")
plt.figure(figsize=(8,6))
sinplot()



sns.set_context("poster")
plt.figure(figsize=(8,6))
sinplot()



sns.set_context("notebook",font_scale=1.5, rc={"lines.linewidth": 2.5})
plt.figure(figsize=(8,6))
sinplot()



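The parameters of set_context() control the scale of plot elements: font_scale multiplies all font sizes, and rc overrides individual settings such as lines.linewidth.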
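sns.plotting_context() returns the same settings that set_context() would apply, without changing global state, so it is a convenient way to compare the contexts numerically. A minimal sketch printing the base font size and line width for each context, plus the custom settings used above:

```python
import seaborn as sns

# plotting_context() only returns a dict of rc settings; nothing is applied.
for name in ["paper", "notebook", "talk", "poster"]:
    ctx = sns.plotting_context(name)
    print(f"{name:>8}: font.size={ctx['font.size']:.1f}, "
          f"lines.linewidth={ctx['lines.linewidth']:.2f}")

# font_scale multiplies only the font sizes; rc overrides single keys.
custom = sns.plotting_context("notebook", font_scale=1.5,
                              rc={"lines.linewidth": 2.5})
print(custom["font.size"], custom["lines.linewidth"])
```

The numbers grow from paper to poster, which is why the same figure looks progressively bolder in the plots above.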

Color palettes

  • Qualitative (categorical) palettes

The default color cycle:

current_palette = sns.color_palette()
sns.palplot(current_palette)




  • Circular palettes

Evenly spaced colors drawn from a circular color space (lightness and saturation stay constant):

sns.palplot(sns.color_palette("hls", 8))
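One way to see what "evenly spaced in a circular color space" means is to convert the palette back to HLS with Python's colorsys module. This sketch relies on seaborn's documented hls_palette() defaults (h=.01, l=.6, s=.65), which color_palette("hls", n) uses:

```python
import colorsys
import numpy as np
import seaborn as sns

pal = sns.color_palette("hls", 8)   # 8 (r, g, b) tuples with values in [0, 1]
print(pal.as_hex())                 # the same colors as hex strings

# Back to HLS: hues step evenly around the color wheel, while lightness and
# saturation stay at hls_palette()'s defaults (l=.6, s=.65).
hls_values = [colorsys.rgb_to_hls(r, g, b) for r, g, b in pal]
hues = np.array([h for h, _, _ in hls_values])
print(np.round(np.diff(hues), 3))   # a constant step of 1/8
```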



data = np.random.normal(size=(20,8)) + np.arange(8) / 2
sns.boxplot(data=data, palette=sns.color_palette("hls",8))
<AxesSubplot:>




The hls_palette() function controls the lightness (l) and saturation (s) of the colors:

sns.palplot(sns.hls_palette(8, l=.3, s=.8))




sns.palplot(sns.color_palette("Paired", 10)) # colors come in light/dark pairs




Name colors using the xkcd color names:

plt.plot([0,1], [0,1], sns.xkcd_rgb["pale red"], lw=3)
plt.plot([0,1], [0,2], sns.xkcd_rgb["medium green"], lw=3)
plt.plot([0,1], [0,3], sns.xkcd_rgb["denim blue"], lw=3)
[<matplotlib.lines.Line2D at 0x1ae80b948b0>]



colors = ["windows blue", "amber", "greyish", "faded green", "dusty purple"]
sns.palplot(sns.xkcd_palette(colors))
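sns.xkcd_rgb is a plain dictionary from xkcd color names to hex strings, and xkcd_palette() looks up those same names and returns them as a palette. A small sketch confirming the two agree:

```python
import matplotlib.colors as mcolors
import seaborn as sns

names = ["windows blue", "amber", "greyish", "faded green", "dusty purple"]

# Dictionary lookup: name -> "#rrggbb"
hex_codes = [sns.xkcd_rgb[name] for name in names]
print(len(sns.xkcd_rgb), hex_codes)

# xkcd_palette() returns the same colors as RGB tuples, so converting them
# back to hex reproduces the lookups above.
pal = sns.xkcd_palette(names)
print([mcolors.to_hex(c) for c in pal] == [h.lower() for h in hex_codes])
```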




  • Sequential palettes
sns.palplot(sns.color_palette("Blues"))




sns.palplot(sns.color_palette("BuGn_r")) # the _r suffix reverses the palette




  • Cubehelix palettes

Brightness changes linearly while the hue rotates:

sns.palplot(sns.color_palette("cubehelix",8))



sns.palplot(sns.cubehelix_palette(8, start=.5, rot=-.75))



sns.palplot(sns.cubehelix_palette(8, start=.75, rot=-.15))
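The same cubehelix parameters can produce either a discrete palette (a list of colors, for hue levels) or, with as_cmap=True, a continuous colormap for functions such as heatmap() or kdeplot(). A sketch using randomly generated data (an assumption, for illustration only):

```python
import numpy as np
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import seaborn as sns

# Discrete palette: 8 RGB tuples.
pal = sns.cubehelix_palette(8, start=.5, rot=-.75)
# Same parameters as a continuous matplotlib Colormap.
cmap = sns.cubehelix_palette(start=.5, rot=-.75, as_cmap=True)

data = np.random.default_rng(0).random((5, 5))   # made-up values in [0, 1)
fig, ax = plt.subplots()
sns.heatmap(data, cmap=cmap, ax=ax)
print(len(pal), isinstance(cmap, mcolors.Colormap))
```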




Sequential light and dark palettes:

sns.palplot(sns.light_palette("green"))



sns.palplot(sns.dark_palette("purple"))



sns.palplot(sns.light_palette("navy", reverse=True))



x, y = np.random.multivariate_normal([0,0], [[1,-.5],[-.5,1]], size=300).T # a valid covariance matrix (correlation -0.5)
pal = sns.dark_palette("green", as_cmap=True)
sns.kdeplot(x=x, y=y, cmap=pal)
<AxesSubplot:>
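If np.random.multivariate_normal warns "covariance is not positive-semidefinite", the matrix passed in is not a valid covariance matrix. With unit variances, an off-diagonal entry of -5 would mean a correlation of -5, which is impossible because correlations lie in [-1, 1]. A quick check is to look at the eigenvalues; `is_valid_covariance` below is a hypothetical helper written for this example, not a NumPy function.

```python
import numpy as np


def is_valid_covariance(cov, tol=1e-12):
    """A covariance matrix must be symmetric and positive semidefinite."""
    cov = np.asarray(cov, dtype=float)
    if not np.allclose(cov, cov.T):
        return False
    # eigvalsh() assumes a symmetric matrix and returns real eigenvalues.
    return bool(np.all(np.linalg.eigvalsh(cov) >= -tol))


print(np.linalg.eigvalsh([[1, -5], [-5, 1]]))    # one eigenvalue is negative
print(is_valid_covariance([[1, -5], [-5, 1]]))
print(is_valid_covariance([[1, -.5], [-.5, 1]]))
```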




Univariate distribution plots

Inspect the distribution of a feature:

sns.set(color_codes=True)
np.random.seed(sum(map(ord, "distributions")))
x = np.random.normal(size=100)
sns.histplot(x) # histogram only (distplot is deprecated); pass kde=True to overlay a kernel density estimate
<AxesSubplot:>



sns.histplot(x, bins=20)
<AxesSubplot:>




Fit a parametric distribution and draw the fitted curve:

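from scipy import stats
x = np.random.gamma(6, size=200)
ax = sns.histplot(x, stat="density") # normalize so the histogram is comparable to a PDF
params = stats.gamma.fit(x) # maximum-likelihood estimate of (a, loc, scale)
grid = np.linspace(x.min(), x.max(), 200)
ax.plot(grid, stats.gamma.pdf(grid, *params)) # overlay the fitted gamma PDF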



import pandas as pd
mean, cov = [0,1], [(1,.5), (.5,1)]
data = np.random.multivariate_normal(mean, cov, 200) # draw 200 samples from a bivariate normal with the given mean and covariance matrix
df = pd.DataFrame(data, columns=["x", "y"])
df
            x         y
0    2.190873  2.902961
1    0.387901  3.441322
2   -1.304909  0.586173
3   -0.016867  0.907323
4    0.284953  1.189304
..        ...       ...
195 -0.804338  0.139381
196  1.674393  2.735944
197 -1.237634  0.002766
198 -1.044683  0.482758
199 -0.890160  0.042753

200 rows × 2 columns

Scatter plot (jointplot also shows the marginal histograms):

sns.jointplot(x="x", y="y", data=df)
<seaborn.axisgrid.JointGrid at 0x1ae80972310>



x,y = np.random.multivariate_normal(mean, cov, 1000).T
with sns.axes_style("white"):
    sns.jointplot(x=x, y=y, kind="hex", color="k") # hexbin plots suit large datasets




Regression plots

iris = sns.load_dataset("iris")
sns.pairplot(iris)
<seaborn.axisgrid.PairGrid at 0x1ae80b0e040>



sns.set(color_codes=True)
np.random.seed(sum(map(ord, "regression")))
tips = sns.load_dataset("tips")
tips.head()
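   total_bill   tip     sex smoker  day    time  size
0       16.99  1.01  Female     No  Sun  Dinner     2
1       10.34  1.66    Male     No  Sun  Dinner     3
2       21.01  3.50    Male     No  Sun  Dinner     3
3       23.68  3.31    Male     No  Sun  Dinner     2
4       24.59  3.61  Female     No  Sun  Dinner     4

Both regplot() and lmplot() draw a scatter plot with a fitted regression line; regplot() is recommended here. regplot() is an axes-level function, while lmplot() is figure-level (it builds a FacetGrid), which is why the two return different objects.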

sns.regplot(x="total_bill", y="tip", data=tips)
<AxesSubplot:xlabel='total_bill', ylabel='tip'>
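With the default order=1, regplot() fits an ordinary least-squares line. A sketch on synthetic data (made up here, not the tips dataset) comparing the drawn line with np.polyfit; ci=None only skips the bootstrap confidence band so that the first Line2D on the Axes is the regression line:

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

rng = np.random.default_rng(0)
bill = rng.uniform(5, 50, size=200)                    # synthetic x values
tip = 0.15 * bill + rng.normal(scale=1.0, size=200)    # synthetic y values

fig, ax = plt.subplots()
sns.regplot(x=bill, y=tip, ax=ax, ci=None)

# Recover slope and intercept from the endpoints of the drawn line.
xs, ys = ax.lines[0].get_data()
slope_drawn = (ys[-1] - ys[0]) / (xs[-1] - xs[0])
intercept_drawn = ys[0] - slope_drawn * xs[0]

# The least-squares fit from NumPy should agree.
slope_np, intercept_np = np.polyfit(bill, tip, 1)
print(round(slope_drawn, 4), round(slope_np, 4))
```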



sns.lmplot(x="total_bill", y="tip", data=tips)
<seaborn.axisgrid.FacetGrid at 0x1ae8286f670>




sns.regplot(x="size", y="tip", data=tips, x_jitter=.05) # x_jitter: add random noise to x so overlapping points separate (the fit is unaffected)
<AxesSubplot:xlabel='size', ylabel='tip'>




Multivariate plots (categorical variables)

np.random.seed(sum(map(ord, "categorical")))
titanic = sns.load_dataset("titanic")
sns.stripplot(x="day", y="total_bill",data=tips)
<AxesSubplot:xlabel='day', ylabel='total_bill'>




Dealing with overlapping points:

sns.stripplot(x="day", y="total_bill",data=tips, jitter=True) # spread points randomly along the category axis
<AxesSubplot:xlabel='day', ylabel='total_bill'>



sns.swarmplot(x="day", y="total_bill", data=tips)
<AxesSubplot:xlabel='day', ylabel='total_bill'>



sns.swarmplot(x="day", y="total_bill", hue="sex",data=tips)
<AxesSubplot:xlabel='day', ylabel='total_bill'>



sns.swarmplot(x="total_bill", y="day",hue="time", data=tips)
<AxesSubplot:xlabel='total_bill', ylabel='day'>




Box plots

sns.boxplot(x="day", y="total_bill", hue="time", data=tips)
<AxesSubplot:xlabel='day', ylabel='total_bill'>



sns.violinplot(x="total_bill", y="day", hue="time", data=tips)
<AxesSubplot:xlabel='total_bill', ylabel='day'>



sns.violinplot(x="day", y="total_bill", hue="sex", data=tips, split=True)
<AxesSubplot:xlabel='day', ylabel='total_bill'>



sns.violinplot(x="day", y="total_bill", data=tips, inner=None)
sns.swarmplot(x="day", y="total_bill", data=tips, color="w", alpha=.5)
<AxesSubplot:xlabel='day', ylabel='total_bill'>




Bar plots show the central tendency of a variable (by default the mean, with an error bar for its confidence interval):

sns.barplot(x="sex", y="survived", hue="class", data=titanic)
<AxesSubplot:xlabel='sex', ylabel='survived'>
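Because barplot() aggregates with the mean by default, each bar height should equal the group mean. A sketch on a tiny made-up DataFrame (the column names `group` and `value` are illustrative):

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Made-up data: means are a=2, b=4, c=5.
df = pd.DataFrame({
    "group": ["a", "a", "b", "b", "b", "c", "c"],
    "value": [1.0, 3.0, 2.0, 4.0, 6.0, 4.0, 6.0],
})
order = ["a", "b", "c"]

fig, ax = plt.subplots()
sns.barplot(x="group", y="value", data=df, order=order, ax=ax)

# One bar per group, in the requested order; error bars are lines, not patches.
heights = [patch.get_height() for patch in ax.patches]
means = df.groupby("group")["value"].mean().loc[order].tolist()
print(heights, means)
```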




Point plots make it easier to compare how the estimate changes across levels:

sns.pointplot(x="sex", y="survived", hue="class", data=titanic)
<AxesSubplot:xlabel='sex', ylabel='survived'>



sns.pointplot(x="class", y="survived", hue="sex", data=titanic, palette={"male":"g", "female":"m"},
markers=["^", "o"], linestyles=["-","--"])
<AxesSubplot:xlabel='class', ylabel='survived'>




Multi-panel categorical plots

sns.catplot(x="day", y="total_bill", hue="smoker", data=tips, kind="point") # factorplot was renamed catplot; factorplot's default kind was "point"
<seaborn.axisgrid.FacetGrid at 0x1ae82d7d460>



sns.catplot(x="day", y="total_bill", hue="smoker", data=tips, kind="bar")
<seaborn.axisgrid.FacetGrid at 0x1ae83fe0940>



sns.catplot(x="day", y="total_bill", hue="smoker", data=tips, kind="swarm")
<seaborn.axisgrid.FacetGrid at 0x1ae8406c0d0>



sns.catplot(x="time", y="total_bill", hue="smoker", col="day", data=tips, kind="box", height=4, aspect=.5)
# height: height of each facet (inches); aspect: each facet's width = aspect * height
<seaborn.axisgrid.FacetGrid at 0x1ae83fa9be0>




FacetGrid

Show subsets of the dataset in separate panels:

g = sns.FacetGrid(tips, col="time") # time has two levels: Lunch and Dinner



g = sns.FacetGrid(tips, col="time") 
g.map(plt.hist, "tip") # distribution of tip in each panel
<seaborn.axisgrid.FacetGrid at 0x1ae84305ca0>



g = sns.FacetGrid(tips, col="sex", hue="smoker") 
g.map(plt.scatter, "total_bill", "tip", alpha=.7)
g.add_legend() # add a legend
<seaborn.axisgrid.FacetGrid at 0x1ae84436e80>



g = sns.FacetGrid(tips, row="smoker", col="sex", hue="smoker", margin_titles=True) 
g.map(sns.regplot, "size", "total_bill", color=".3", fit_reg=True, x_jitter=.1)
<seaborn.axisgrid.FacetGrid at 0x1ae844499a0>



g = sns.FacetGrid(tips, col="day", height=4, aspect=.5)
g.map(sns.barplot, "sex", "total_bill", order=["Male", "Female"]) # an explicit order keeps the bars consistent across facets
<seaborn.axisgrid.FacetGrid at 0x1ae846a11f0>



ordered_days = tips.day.value_counts().index
print(ordered_days)
g = sns.FacetGrid(tips, row="day", row_order=ordered_days, height=1.7, aspect=4)
g.map(sns.boxplot, "total_bill")
CategoricalIndex(['Sat', 'Sun', 'Thur', 'Fri'], categories=['Thur', 'Fri', 'Sat', 'Sun'], ordered=False, dtype='category')

d:\develop\Anaconda\lib\site-packages\seaborn\axisgrid.py:670: UserWarning: Using the boxplot function without specifying `order` is likely to produce an incorrect plot.
  warnings.warn(warning)





<seaborn.axisgrid.FacetGrid at 0x1ae847ef910>



pal = dict(Lunch="seagreen", Dinner="gray")
g = sns.FacetGrid(tips, hue="time", palette=pal, height=5)
g.map(plt.scatter, "total_bill", "tip", s=50, alpha=.7, linewidth=.5, edgecolor="white")
g.add_legend()
<seaborn.axisgrid.FacetGrid at 0x1ae846636d0>



g = sns.FacetGrid(tips, hue="sex", palette="Set1", height=5, hue_kws={"marker":["^","v"]})
g.map(plt.scatter, "total_bill", "tip", s=100, linewidth=.5, edgecolor="white")
g.add_legend()
<seaborn.axisgrid.FacetGrid at 0x1ae85a019a0>



with sns.axes_style("white"):
    g = sns.FacetGrid(tips, row="sex", col="smoker", margin_titles=True, height=2.5)
    g.map(plt.scatter, "total_bill", "tip", color="#334488", edgecolor="white", lw=.5)
    g.set_axis_labels("Total bill (US Dollars)", "Tip")
    g.set(xticks=[10, 30, 50], yticks=[2, 6, 10])
    g.fig.subplots_adjust(wspace=.02, hspace=.02) # spacing between subplots


iris = sns.load_dataset("iris")
g = sns.PairGrid(iris)
g.map(plt.scatter)
<seaborn.axisgrid.PairGrid at 0x1ae85980b20>



iris = sns.load_dataset("iris")
g = sns.PairGrid(iris)
g.map_diag(plt.hist) # histograms on the diagonal
g.map_offdiag(plt.scatter) # scatter plots off the diagonal
<seaborn.axisgrid.PairGrid at 0x1ae86c26ca0>
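PairGrid builds one row and one column per numeric variable, so the grid size follows from the data. A sketch with a synthetic DataFrame (three numeric columns plus one text column, all made up) to show that non-numeric columns are left out by default:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

rng = np.random.default_rng(2)
df = pd.DataFrame(rng.normal(size=(50, 3)), columns=["a", "b", "c"])
df["label"] = rng.choice(["x", "y"], size=50)   # text column, skipped by PairGrid

g = sns.PairGrid(df)
g.map_diag(plt.hist)         # univariate view on the diagonal
g.map_offdiag(plt.scatter)   # pairwise scatter everywhere else
print(g.axes.shape, list(g.x_vars))
```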



g = sns.PairGrid(iris, hue="species")
g.map_diag(plt.hist)
g.map_offdiag(plt.scatter)
g.add_legend()
<seaborn.axisgrid.PairGrid at 0x1e1f7886460>



g = sns.PairGrid(tips, hue="size", palette="GnBu_d")
g.map(plt.scatter, s=50, edgecolor="white")
g.add_legend()
<seaborn.axisgrid.PairGrid at 0x1e1f8d30790>




Heatmaps

np.random.seed(0)
sns.set()
uniform_data = np.random.rand(3,3)
print(uniform_data)
heat_map = sns.heatmap(uniform_data)
[[0.5488135  0.71518937 0.60276338]
 [0.54488318 0.4236548  0.64589411]
 [0.43758721 0.891773   0.96366276]]


ax = sns.heatmap(uniform_data, vmin=0.2, vmax=0.5)
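vmin and vmax pin the two ends of the color scale instead of letting them follow the data range; cells outside the range are drawn with the end colors, so differences among them are no longer visible. A sketch on made-up random data that reads the limits back from the heatmap's mesh:

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

rng = np.random.default_rng(0)
data = rng.random((3, 3))            # made-up values in [0, 1)

fig, ax = plt.subplots()
sns.heatmap(data, vmin=0.2, vmax=0.5, ax=ax)

# The colored cells are a QuadMesh; its color limits are exactly vmin/vmax.
mesh = ax.collections[0]
print(mesh.get_clim())
print("cells at or above vmax:", int((data >= 0.5).sum()))
```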



normal_data = np.random.randn(3,3)
print(normal_data)
ax = sns.heatmap(normal_data, center=0)
[[ 1.26611853 -0.50587654  2.54520078]
 [ 1.08081191  0.48431215  0.57914048]
 [-0.18158257  1.41020463 -0.37447169]]


flights = sns.load_dataset("flights")
flights.head()
   year month  passengers
0  1949   Jan         112
1  1949   Feb         118
2  1949   Mar         132
3  1949   Apr         129
4  1949   May         121
flights = flights.pivot(index="month", columns="year", values="passengers")
print(flights)
ax = sns.heatmap(flights)
year   1949  1950  1951  1952  1953  1954  1955  1956  1957  1958  1959  1960
month                                                                        
Jan     112   115   145   171   196   204   242   284   315   340   360   417
Feb     118   126   150   180   196   188   233   277   301   318   342   391
Mar     132   141   178   193   236   235   267   317   356   362   406   419
Apr     129   135   163   181   235   227   269   313   348   348   396   461
May     121   125   172   183   229   234   270   318   355   363   420   472
Jun     135   149   178   218   243   264   315   374   422   435   472   535
Jul     148   170   199   230   264   302   364   413   465   491   548   622
Aug     148   170   199   242   272   293   347   405   467   505   559   606
Sep     136   158   184   209   237   259   312   355   404   404   463   508
Oct     119   133   162   191   211   229   274   306   347   359   407   461
Nov     104   114   146   172   180   203   237   271   305   310   362   390
Dec     118   140   166   194   201   229   278   306   336   337   405   432
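The pivot() step turns the long table (one row per month and year) into the wide month-by-year matrix that heatmap() expects. A sketch on a four-row excerpt of the flights values printed above; note that plain strings sort alphabetically, so the calendar order is restored by hand here (the real dataset's month column is categorical and keeps calendar order):

```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# First two months of 1949 and 1950, copied from the table above.
long_df = pd.DataFrame({
    "year": [1949, 1949, 1950, 1950],
    "month": ["Jan", "Feb", "Jan", "Feb"],
    "passengers": [112, 118, 115, 126],
})

# One row per month, one column per year, cells hold passengers.
wide = long_df.pivot(index="month", columns="year", values="passengers")
wide = wide.loc[["Jan", "Feb"]]   # restore calendar order

fig, ax = plt.subplots()
sns.heatmap(wide, annot=True, fmt="d", ax=ax)   # fmt="d" formats integer counts
labels = [t.get_text() for t in ax.texts]
print(wide)
print(labels)
```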



Show the values on the heatmap:

ax = sns.heatmap(flights, annot=True, fmt='d')




ax = sns.heatmap(flights, linewidths=.5) # width of the lines separating cells



ax = sns.heatmap(flights, cmap="YlGnBu")




ax = sns.heatmap(flights, cbar=False) # hide the color bar

