/// <summary>
/// Represents a Dirichlet distribution: a continuous multivariate distribution over the
/// probability simplex, parameterized by a vector of concentration parameters α = (α_1, ..., α_K).
/// </summary>
public sealed class DirichletDistribution : MultivariateContinuousDistribution
{
    /// <summary>
    /// Absolute tolerance (2^-26, the square root of double-precision machine epsilon) used to
    /// decide whether a full-length argument to <see cref="LogProbabilityDensityFunction"/>
    /// sums to one.
    /// </summary>
    private const double SimplexSumTolerance = 1.4901161193847656E-08;

    // Concentration parameters α_1..α_K.
    private Vector alpha;

    // α_0 = Σ α_i. NaN means the cached quantities below have not been computed yet.
    private double sum = double.NaN;

    // Log of the density's normalizing constant: ln Γ(α_0) − Σ ln Γ(α_i).
    private double logNormalizer;

    // Cached mean vector, component i = α_i / α_0.
    private Vector mean;

    // Cached variance-covariance matrix (built on first request).
    private SymmetricMatrix covarianceMatrix;

    /// <summary>
    /// Computes and caches α_0, the log normalizing constant and the mean vector, unless that
    /// has already been done.
    /// </summary>
    private void DoPrecomputations()
    {
        if (!double.IsNaN(this.sum))
        {
            return;
        }

        double alphaSum = this.alpha.GetSum();
        double logNorm = GammaFunctions.LogGamma(alphaSum);
        Vector means = Vector.Create(this.alpha.Length);
        for (int i = 0; i < this.alpha.Length; i++)
        {
            double a = this.alpha.GetValue(i);
            logNorm -= GammaFunctions.LogGamma(a);
            means[i] = a / alphaSum;
        }

        // Assign the cached fields only once everything is computed, with `sum` last, so an
        // exception thrown part-way leaves the cache marked as "not computed".
        this.logNormalizer = logNorm;
        this.mean = means;
        this.sum = alphaSum;
    }

    /// <summary>
    /// Fills <paramref name="sample"/> with a random Dirichlet variate. One Gamma variate is
    /// drawn per component, with parameter α_i (presumably the shape, at unit scale). The
    /// vector is then normalized to sum to one; normalization cancels any scale common to all
    /// components.
    /// </summary>
    /// <param name="random">The random number generator to draw from.</param>
    /// <param name="sample">Destination vector of length <c>Order</c>; overwritten.</param>
    protected override void FillRandomVariateCore(System.Random random, Vector sample)
    {
        for (int i = 0; i < base.Order; i++)
        {
            sample[i] = GammaDistribution.GetRandomVariate(random, this.alpha[i]);
        }
        sample.Multiply(1.0 / sample.GetSum());
    }

    /// <summary>
    /// Constructs a Dirichlet distribution with the given concentration parameters.
    /// </summary>
    /// <param name="parameters">
    /// The concentration parameters α_i. There must be at least one, and all must be strictly
    /// positive. The vector is stored by reference, not copied.
    /// </param>
    public DirichletDistribution(Vector parameters)
        : base(Vector.SafeLength(parameters))
    {
        if (parameters == null)
        {
            ThrowException.ArgumentNull("parameters");
        }
        if (parameters.Length == 0)
        {
            ThrowException.ArgumentOutOfRange("parameters");
        }
        for (int i = 0; i < parameters.Length; i++)
        {
            // Negated comparison so that NaN is rejected as well.
            if (!(parameters.GetValue(i) > 0.0))
            {
                ThrowException.ArgumentOutOfRange("parameters");
            }
        }
        this.alpha = parameters;
    }

    /// <summary>
    /// Constructs a Dirichlet distribution from a data matrix, one column per component.
    /// </summary>
    /// <param name="data">The data. Each column's mean becomes one concentration parameter.</param>
    public DirichletDistribution(Matrix data)
        : base((data == null) ? 0 : data.ColumnCount)
    {
        if (data == null)
        {
            ThrowException.ArgumentNull("data");
        }
        if (data.ColumnCount == 0)
        {
            ThrowException.ArgumentOutOfRange("data");
        }
        // NOTE(review): the column means are used directly as concentration parameters. This is
        // not a maximum-likelihood or method-of-moments estimate of α; confirm that is intended.
        this.alpha = data.ApplyToColumns(Stats.Mean);
    }

    /// <summary>
    /// Constructs a Dirichlet distribution from a set of variables, one per component.
    /// </summary>
    /// <param name="variables">
    /// The variables. The array must be non-empty with no null elements. Each variable's mean
    /// becomes one concentration parameter.
    /// </param>
    public DirichletDistribution(NumericalVariable[] variables)
        : base((variables != null) ? variables.Length : 0)
    {
        if (variables == null)
        {
            ThrowException.ArgumentNull("variables");
        }
        for (int i = 0; i < variables.Length; i++)
        {
            if (variables[i] == null)
            {
                throw new ArgumentException("The variables array must not contain null elements.", "variables");
            }
        }
        if (variables.Length == 0)
        {
            ThrowException.ArgumentOutOfRange("variables");
        }
        // NOTE(review): as in the Matrix constructor, the variable means are used directly as α.
        this.alpha = new GeneralVector(Array.ConvertAll(variables, Stats.Mean), reuseComponentArray: true);
    }

    /// <summary>
    /// Returns the natural logarithm of the probability density at <paramref name="x"/>.
    /// </summary>
    /// <param name="x">
    /// Either a full point of length K, whose components must sum to one within
    /// <see cref="SimplexSumTolerance"/>, or its first K−1 components. In the second case the
    /// last component is taken to be one minus their sum.
    /// </param>
    /// <returns>
    /// The log density, or <see cref="double.NegativeInfinity"/> (log of zero) when
    /// <paramref name="x"/> lies outside the open simplex.
    /// </returns>
    /// <remarks>
    /// Throws (via ThrowException) when <paramref name="x"/> is null or its length is neither
    /// K nor K−1.
    /// </remarks>
    public override double LogProbabilityDensityFunction(Vector x)
    {
        if (x == null)
        {
            ThrowException.ArgumentNull("x");
        }
        int dimension = this.alpha.Length;
        bool lastComponentImplied = x.Length == dimension - 1;
        if (x.Length != dimension && !lastComponentImplied)
        {
            ThrowException.LengthMismatch("x");
        }
        this.DoPrecomputations();

        // Accumulate Σ (α_i − 1) ln x_i and Σ x_i over the supplied components.
        double logKernel = 0.0;
        double componentSum = 0.0;
        for (int i = 0; i < x.Length; i++)
        {
            double value = x.GetValue(i);
            if (value <= 0.0 || value >= 1.0)
            {
                return double.NegativeInfinity;
            }
            logKernel += (this.alpha[i] - 1.0) * Math.Log(value);
            componentSum += value;
        }

        if (lastComponentImplied)
        {
            // The implied last component 1 − Σ x_i must be strictly positive.
            if (componentSum >= 1.0)
            {
                return double.NegativeInfinity;
            }
            logKernel += (this.alpha[dimension - 1] - 1.0) * Math.Log(1.0 - componentSum);
        }
        else if (Math.Abs(componentSum - 1.0) > SimplexSumTolerance)
        {
            return double.NegativeInfinity;
        }
        return this.logNormalizer + logKernel;
    }

    /// <summary>
    /// Returns the concentration parameters, obtained via <c>ToGeneralVector()</c> on the
    /// stored vector.
    /// </summary>
    public Vector GetParameters()
    {
        return this.alpha.ToGeneralVector();
    }

    /// <summary>
    /// Returns the mean vector, with component i equal to α_i / α_0.
    /// </summary>
    /// <remarks>The cached vector instance is returned, not a copy.</remarks>
    public override Vector GetMeans()
    {
        this.DoPrecomputations();
        return this.mean;
    }

    /// <summary>
    /// Returns the variance-covariance matrix. Entry (i, j) is
    /// (α_i·α_0·δ_ij − α_i·α_j) / (α_0²·(α_0 + 1)).
    /// </summary>
    /// <remarks>The matrix is computed on first call and the cached instance is returned.</remarks>
    public override SymmetricMatrix GetVarianceCovarianceMatrix()
    {
        if (this.covarianceMatrix == null)
        {
            this.DoPrecomputations();
            this.covarianceMatrix = Matrix.CreateSymmetric(this.alpha.Length);
            double factor = 1.0 / (this.sum * this.sum * (this.sum + 1.0));
            // Off-diagonal part: −α_i·α_j·factor for every (i, j).
            this.covarianceMatrix.AddOuterProduct(0.0 - factor, this.alpha);
            // Add α_0·α_i·factor to each diagonal entry, so variance = α_i(α_0 − α_i)·factor.
            this.covarianceMatrix.GetDiagonal().Add(this.sum * factor, this.alpha);
        }
        return this.covarianceMatrix;
    }
}
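// If this was helpful to you, thank you very much for supporting the creator's work!
// Thank you for supporting technical sharing; please scan the QR code to like and support:
// Technical cooperation and exchange QQ: 2401315930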