转载请标明作者
首先是狄利克雷分布的三维实现,由于没法保证下面的坐标值加起来唯一,所以我采用了抽样的方法,从dirchidirichlet~(1,1,1)中抽的三维图像的坐标值,我只去前两个作为我的x,y值,代码如下:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy import stats
global LEN
LEN=1000
#抽样
rv=stats.dirichlet.rvs([1,1,1],3000)
#将前两个值作为我的x,y值
rvs=rv[:,0:2]
list=[]
for x in rv:
#计算其对应的pdf值
pdf=stats.dirichlet.pdf(x,alpha=[5,10,7])
list.append(float(pdf))
fig=plt.figure()
ax=fig.gca(projection="3d")
ax.scatter(rvs[:,0],rvs[:,1],list)
plt.show()
由于,我采用的是抽样的方式,所以如果数据少的话,会出现空白部位,得到的图像为:
接下来介绍一下 stick-breaking的实现,一般用于狄利克雷过程中数据的抽样,代码如下 :
import matplotlib.pyplot as plt
import numpy as np
#my stick breaking sampling
#使用的基函数是高斯
#dp~dp(alpha,gaussian(0,1))
def stick_breaking(alpha,sample_num):
x_list=[]
y_list=[]
for x in range(sample_num):
data_x=np.random.normal(0,1,1)
beta=np.random.beta(1,alpha,1)
y=(1-sum(y_list))*beta
x_list.append(data_x)
y_list.append(y)
return x_list,y_list
alpha=1000
sample_num=500#样本数
x_list,y_list=stick_breaking(alpha,sample_num)
plt.stem(x_list,y_list)
plt.show()
得到的图像如下: