使用案例
import pandas as pd
import numpy as np
import holoviews as hv
import plotly.graph_objects as go
import plotly.express as pex
import pandas as pd
# Render holoviews objects with the bokeh backend.
hv.extension("bokeh")

# Demo data, one tuple per row: gender, reason for staying up late,
# an extra wake-up category, and the head count for that combination.
records = [
    ("男", "打游戏", "早起", 57),
    ("男", "加班", "早起", 13),
    ("男", "看剧", "晚起", 30),
    ("女", "打游戏", "中午起", 33),
    ("女", "加班", "早起", 5),
    ("女", "看剧", "早起", 62),
]
df = pd.DataFrame(records, columns=["性别", "熬夜原因", "额外列", "人数"])

# Two-column Sankey: flows run from 额外列 to 熬夜原因, weighted by 人数.
sankey1 = hv.Sankey(df, kdims=["额外列", "熬夜原因"], vdims=["人数"])

# Plot options kept in one place so they are easy to tweak.
sankey_style = dict(
    cmap='Colorblind',
    label_position='left',
    edge_color='熬夜原因',
    edge_line_width=0,
    node_alpha=1.0,
    node_width=40,
    node_sort=True,
    width=800,
    height=600,
    bgcolor="snow",
    title="熬夜分布图",
)
sankey1.opts(**sankey_style)
最终结果如下:
如果想画三列(三个层级)的桑基图,可以用 plotly 实现:
from plotly.offline import init_notebook_mode, iplot
from plotly.graph_objs import *
# Initialise plotly's offline notebook mode before iplot() is called below.
# NOTE(review): connected=True presumably loads plotly.js from the CDN
# instead of embedding it in the notebook; confirm against the plotly docs.
init_notebook_mode(connected=True)
def genSankey(df, cat_cols=None, value_cols='', title='Sankey Diagram'):
    """Build a plotly Sankey figure dict from categorical columns of ``df``.

    Every adjacent pair of columns in ``cat_cols`` contributes links from the
    left column's values to the right column's values; the link weight is the
    sum of ``value_cols`` over the rows sharing that (source, target) pair.
    Each node is coloured according to the column in which its label first
    appears.

    Args:
        df: pandas DataFrame holding the category columns and the value column.
        cat_cols: Ordered category column names (left to right). At least two
            are required, otherwise there is nothing to link.
        value_cols: Name of the numeric column that is summed into link values.
        title: Title placed in the figure layout.

    Returns:
        A ``dict`` with ``data`` (a single Sankey trace) and ``layout`` keys.

    Raises:
        ValueError: If fewer than two category columns are supplied.
    """
    cat_cols = [] if cat_cols is None else list(cat_cols)
    if len(cat_cols) < 2:
        raise ValueError(
            "cat_cols must contain at least two column names to form links"
        )

    # One colour per column (level). The palette is cycled so that more
    # columns than colours no longer raises IndexError.
    level_palette = ['#4B8BBE', '#306998', '#FFE873', '#FFD43B', '#646464']

    # Map label -> colour in first-seen order. Building labels and colours
    # together keeps them aligned even when the same label occurs in several
    # columns (the label keeps the colour of the first column it appears in).
    node_colors = {}
    for level, col in enumerate(cat_cols):
        level_color = level_palette[level % len(level_palette)]
        for label in df[col].unique():
            node_colors.setdefault(label, level_color)
    label_list = list(node_colors)
    color_list = list(node_colors.values())
    label_index = {label: pos for pos, label in enumerate(label_list)}

    # Stack one (source, target, count) frame per adjacent column pair.
    pair_frames = []
    for src_col, dst_col in zip(cat_cols, cat_cols[1:]):
        pair = df[[src_col, dst_col, value_cols]].copy()
        pair.columns = ['source', 'target', 'count']
        pair_frames.append(pair)
    links = (
        pd.concat(pair_frames)
        .groupby(['source', 'target'])
        .agg({'count': 'sum'})
        .reset_index()
    )

    # plotly addresses nodes by their position in the label list.
    links['sourceID'] = [label_index[label] for label in links['source']]
    links['targetID'] = [label_index[label] for label in links['target']]

    data = dict(
        type='sankey',
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=label_list,
            color=color_list,
        ),
        link=dict(
            source=links['sourceID'],
            target=links['targetID'],
            value=links['count'],
        ),
    )
    layout = dict(title=title, font=dict(size=10))
    return dict(data=[data], layout=layout)
import pandas as pd
import plotly
# Build a three-level figure (性别 -> 熬夜原因 -> 额外列) with links weighted by 人数.
fig = genSankey(df,cat_cols=['性别','熬夜原因',"额外列"],value_cols='人数',title='Sanky 测试')
iplot(fig)
# The figure is displayed inline in Jupyter.
最终的结果如图
上面这幅图其实不太好看(我一开始没有意识到):同一列的所有节点都是同一种颜色。所以需要调整节点颜色,改动其实很简单:
from plotly.offline import init_notebook_mode, iplot
from plotly.graph_objs import *
# Initialise plotly's offline notebook mode before iplot() is called below.
# NOTE(review): connected=True presumably loads plotly.js from the CDN
# instead of embedding it in the notebook; confirm against the plotly docs.
init_notebook_mode(connected=True)
def genSankey(df, cat_cols=None, value_cols='', title='Sankey Diagram'):
    """Build a plotly Sankey figure dict without explicit node colours.

    Every adjacent pair of columns in ``cat_cols`` contributes links from the
    left column's values to the right column's values; the link weight is the
    sum of ``value_cols`` over the rows sharing that (source, target) pair.
    No ``color`` is set on the nodes, so the renderer's own colouring applies.

    Args:
        df: pandas DataFrame holding the category columns and the value column.
        cat_cols: Ordered category column names (left to right). At least two
            are required, otherwise there is nothing to link.
        value_cols: Name of the numeric column that is summed into link values.
        title: Title placed in the figure layout.

    Returns:
        A ``dict`` with ``data`` (a single Sankey trace) and ``layout`` keys.

    Raises:
        ValueError: If fewer than two category columns are supplied.
    """
    cat_cols = [] if cat_cols is None else list(cat_cols)
    if len(cat_cols) < 2:
        raise ValueError(
            "cat_cols must contain at least two column names to form links"
        )

    # Unique labels across all columns, in first-seen order so that node
    # numbering is reproducible from run to run.
    label_index = {}
    for col in cat_cols:
        for label in df[col].unique():
            label_index.setdefault(label, len(label_index))
    label_list = list(label_index)

    # Stack one (source, target, count) frame per adjacent column pair.
    pair_frames = []
    for src_col, dst_col in zip(cat_cols, cat_cols[1:]):
        pair = df[[src_col, dst_col, value_cols]].copy()
        pair.columns = ['source', 'target', 'count']
        pair_frames.append(pair)
    links = (
        pd.concat(pair_frames)
        .groupby(['source', 'target'])
        .agg({'count': 'sum'})
        .reset_index()
    )

    # plotly addresses nodes by their position in the label list.
    links['sourceID'] = [label_index[label] for label in links['source']]
    links['targetID'] = [label_index[label] for label in links['target']]

    data = dict(
        type='sankey',
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=label_list,
            # `color` is deliberately left out: per the original author's
            # note, omitting it makes every category show its own colour.
        ),
        link=dict(
            source=links['sourceID'],
            target=links['targetID'],
            value=links['count'],
        ),
    )
    layout = dict(title=title, font=dict(size=10))
    return dict(data=[data], layout=layout)
import pandas as pd
import plotly
# Build a three-level figure (性别 -> 熬夜原因 -> 额外列) with links weighted by 人数.
fig = genSankey(df,cat_cols=['性别','熬夜原因',"额外列"],value_cols='人数',title='Sanky 测试')
iplot(fig)
## This does render inside Jupyter, so it should be possible to save it as well.
结果如下
如果要自定义每个类别(category)的颜色,修改的也是同一个位置(node 的 color 参数):
from plotly.offline import init_notebook_mode, iplot
from plotly.graph_objs import *
import seaborn as sns
# Initialise plotly's offline notebook mode before iplot() is called below.
# NOTE(review): connected=True presumably loads plotly.js from the CDN
# instead of embedding it in the notebook; confirm against the plotly docs.
init_notebook_mode(connected=True)
def genSankey(df, cat_cols=None, value_cols='', title='Sankey Diagram'):
    """Build a plotly Sankey figure dict with seaborn "dark" node colours.

    Every adjacent pair of columns in ``cat_cols`` contributes links from the
    left column's values to the right column's values; the link weight is the
    sum of ``value_cols`` over the rows sharing that (source, target) pair.
    Node colours come from seaborn's "dark" palette, one entry per node.

    Args:
        df: pandas DataFrame holding the category columns and the value column.
        cat_cols: Ordered category column names (left to right). At least two
            are required, otherwise there is nothing to link.
        value_cols: Name of the numeric column that is summed into link values.
        title: Title placed in the figure layout.

    Returns:
        A ``dict`` with ``data`` (a single Sankey trace) and ``layout`` keys.

    Raises:
        ValueError: If fewer than two category columns are supplied.
    """
    cat_cols = [] if cat_cols is None else list(cat_cols)
    if len(cat_cols) < 2:
        raise ValueError(
            "cat_cols must contain at least two column names to form links"
        )

    # Unique labels across all columns, in first-seen order so that node
    # numbering (and therefore colour assignment) is reproducible.
    label_index = {}
    for col in cat_cols:
        for label in df[col].unique():
            label_index.setdefault(label, len(label_index))
    label_list = list(label_index)

    # Stack one (source, target, count) frame per adjacent column pair.
    pair_frames = []
    for src_col, dst_col in zip(cat_cols, cat_cols[1:]):
        pair = df[[src_col, dst_col, value_cols]].copy()
        pair.columns = ['source', 'target', 'count']
        pair_frames.append(pair)
    links = (
        pd.concat(pair_frames)
        .groupby(['source', 'target'])
        .agg({'count': 'sum'})
        .reset_index()
    )

    # plotly addresses nodes by their position in the label list.
    links['sourceID'] = [label_index[label] for label in links['source']]
    links['targetID'] = [label_index[label] for label in links['target']]

    data = dict(
        type='sankey',
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=label_list,
            # Request exactly one colour per node; the previous hard-coded
            # 6 left nodes without a colour once there were more than six.
            color=sns.color_palette("dark", len(label_list)).as_hex(),
        ),
        link=dict(
            source=links['sourceID'],
            target=links['targetID'],
            value=links['count'],
        ),
    )
    layout = dict(title=title, font=dict(size=10))
    return dict(data=[data], layout=layout)
import pandas as pd
import plotly
# Build a three-level figure (性别 -> 熬夜原因 -> 额外列) with links weighted by 人数.
fig = genSankey(df,cat_cols=['性别','熬夜原因',"额外列"],value_cols='人数',title='Sanky 测试')
iplot(fig)
## This does render inside Jupyter, so it should be possible to save it as well.
结果如下
怎样让这个图既能在 Jupyter 中显示,又能保存成文件呢?我之前尝试过,好像没有成功。
from plotly.offline import init_notebook_mode, iplot
from plotly.graph_objs import *
import seaborn as sns
# Initialise plotly's offline notebook mode.
# NOTE(review): connected=True presumably loads plotly.js from the CDN
# instead of embedding it in the notebook; confirm against the plotly docs.
init_notebook_mode(connected=True)
def genSankey(df, cat_cols=None, value_cols='', title='Sankey Diagram'):
    """Build a plotly Sankey figure dict with seaborn "dark" node colours.

    Every adjacent pair of columns in ``cat_cols`` contributes links from the
    left column's values to the right column's values; the link weight is the
    sum of ``value_cols`` over the rows sharing that (source, target) pair.
    Node colours come from seaborn's "dark" palette, one entry per node.

    Args:
        df: pandas DataFrame holding the category columns and the value column.
        cat_cols: Ordered category column names (left to right). At least two
            are required, otherwise there is nothing to link.
        value_cols: Name of the numeric column that is summed into link values.
        title: Title placed in the figure layout.

    Returns:
        A ``dict`` with ``data`` (a single Sankey trace) and ``layout`` keys.

    Raises:
        ValueError: If fewer than two category columns are supplied.
    """
    cat_cols = [] if cat_cols is None else list(cat_cols)
    if len(cat_cols) < 2:
        raise ValueError(
            "cat_cols must contain at least two column names to form links"
        )

    # Unique labels across all columns, in first-seen order so that node
    # numbering (and therefore colour assignment) is reproducible.
    label_index = {}
    for col in cat_cols:
        for label in df[col].unique():
            label_index.setdefault(label, len(label_index))
    label_list = list(label_index)

    # Stack one (source, target, count) frame per adjacent column pair.
    pair_frames = []
    for src_col, dst_col in zip(cat_cols, cat_cols[1:]):
        pair = df[[src_col, dst_col, value_cols]].copy()
        pair.columns = ['source', 'target', 'count']
        pair_frames.append(pair)
    links = (
        pd.concat(pair_frames)
        .groupby(['source', 'target'])
        .agg({'count': 'sum'})
        .reset_index()
    )

    # plotly addresses nodes by their position in the label list.
    links['sourceID'] = [label_index[label] for label in links['source']]
    links['targetID'] = [label_index[label] for label in links['target']]

    data = dict(
        type='sankey',
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=label_list,
            # Request exactly one colour per node; the previous hard-coded
            # 6 left nodes without a colour once there were more than six.
            color=sns.color_palette("dark", len(label_list)).as_hex(),
        ),
        link=dict(
            source=links['sourceID'],
            target=links['targetID'],
            value=links['count'],
        ),
    )
    layout = dict(title=title, font=dict(size=10))
    return dict(data=[data], layout=layout)
import pandas as pd
import plotly
# Build a three-level figure (性别 -> 熬夜原因 -> 额外列) with links weighted by 人数.
fig = genSankey(df, cat_cols=['性别', '熬夜原因', "额外列"], value_cols='人数', title='Sanky 测试')

# Show the figure inline in Jupyter; iplot() alone does not save anything.
iplot(fig)

# Also write a standalone HTML file. auto_open=False keeps plot() from
# opening a browser tab, so the figure is both displayed above and saved.
plotly.offline.plot(fig, filename="./test_sanky_plot.html", auto_open=False)
今天又发现了一种用 R 画漂亮 Sankey 图的方法,记录如下:
# Alluvial (Sankey-style) plot built with ggalluvial.
# NOTE(review): ggplot() and friends are used below without an explicit
# library(ggplot2); presumably ggalluvial attaches it. Confirm.
library(ggalluvial)
# NOTE(review): this result is neither assigned nor used further down;
# it looks like a leftover from exploring the Titanic example data.
as.data.frame(Titanic)
# Demo data: three categorical stages (first, second, third) and a flow weight n.
dat01<-data.frame(first=c("A","A","C"),
second=c("B","D","E"),
third=c("F","G","F"),
n=c(15,30,20))
# Fix the level order of the first stage to C, then A.
dat01$first<-factor(dat01$first,
levels = c("C","A"))
# Each stage is mapped to one axis; n is the flow size.
p=ggplot(data=dat01,
aes(axis1=first,axis2=second,axis3=third,
y=n))+
# Flows are filled according to the value of the second stage.
geom_alluvium(aes(fill=second),
#size=3,
#color="white",
width = 0.1,
aes.bind = "flows")+
# Seven fill colours are supplied by hand; the data has 2 + 3 + 2 distinct
# values across the three stages.
# NOTE(review): which colour lands on which stratum depends on ggalluvial's
# stratum ordering. Verify visually.
geom_stratum(fill=c("red","blue","green",'yellow',
'orange',"blue","green"),
#color="white",
#size=3,
width=0.1)+
# Label every stratum with its own name.
geom_label(stat = "stratum", aes(label = after_stat(stratum)))+
# NOTE(review): breaks/labels are given for axes 1 and 2 only although three
# axes are mapped; theme_void() below presumably hides the axis text anyway.
scale_x_continuous(breaks = c(1,2),
labels = c("first","second"),
expand = expansion(mult = c(0,0)))+
theme_void()
print(p)
结果如下