/// <summary>
/// Represents a multivariate normal (Gaussian) distribution defined by a mean vector
/// and a symmetric variance-covariance matrix.
/// </summary>
/// <remarks>
/// A square-root factor F of the covariance matrix (F * F^T equals the covariance matrix) is
/// computed lazily on first use and cached. It is either the lower Cholesky factor or, when
/// that is unavailable or badly conditioned, the eigenvector matrix with its columns scaled
/// by the square roots of the eigenvalues. F is used both to draw random variates
/// (mean + F * z, with z standard normal) and to evaluate the log density.
/// The mean vector and covariance matrix are stored by reference; mutating them after the
/// factor has been cached leaves the cached factor out of date.
/// </remarks>
public sealed class MultivariateNormalDistribution : MultivariateContinuousDistribution
{
    /// <summary>Natural logarithm of 2π, used in the density normalization constant.</summary>
    private const double LogTwoPi = 1.8378770664093456;

    /// <summary>
    /// Largest estimated condition number for which the Cholesky factor is accepted;
    /// above it the eigenvalue-based factor is used instead.
    /// </summary>
    private const double MaxCholeskyConditionNumber = 10000.0;

    private readonly Vector mean;

    private readonly SymmetricMatrix covarianceMatrix;

    // Square-root factor of the covariance matrix; null until DoPrecomputations has succeeded.
    private Matrix factor;

    // Log-density normalization constant: -(n/2) ln(2π) - ln|det(factor)|.
    private double scale;

    /// <summary>
    /// Computes and caches the covariance square-root factor and the log-density
    /// normalization constant, if this has not already been done.
    /// </summary>
    /// <remarks>
    /// The factor is built in a local and published only after <see cref="scale"/> has been
    /// computed. If the eigen decomposition fails, no half-initialized state is left behind
    /// that would short-circuit later calls with a zero normalization constant.
    /// </remarks>
    private void DoPrecomputations()
    {
        if (this.factor != null)
        {
            return;
        }

        Matrix candidate = null;
        bool wellConditioned = false;
        try
        {
            candidate = new CholeskyDecomposition(this.covarianceMatrix, overwrite: false).LowerTriangularFactor;
            wellConditioned = candidate.EstimateConditionNumber() < MaxCholeskyConditionNumber;
        }
        catch (MatrixNotPositiveDefiniteException)
        {
            // Deliberate fallback: use the eigenvalue-based factor below.
        }
        catch (MatrixSingularException)
        {
            // Deliberate fallback: use the eigenvalue-based factor below.
        }

        if (!wellConditioned)
        {
            // Covariance = V * diag(λ) * V^T, so V * diag(sqrt(λ)) is also a square-root factor.
            SymmetricEigenvalueDecomposition eigen = new SymmetricEigenvalueDecomposition(this.covarianceMatrix);
            candidate = eigen.Eigenvectors.ScaleColumns(Vector.Sqrt(eigen.Eigenvalues));
        }

        // ln|det(F)| equals half of ln det(covariance), because det(covariance) = det(F)^2.
        this.scale = -(0.5 * this.mean.Length * LogTwoPi + Math.Log(Math.Abs(candidate.GetDeterminant())));
        this.factor = candidate;
    }

    /// <summary>
    /// Fills <paramref name="sample"/> with a random variate drawn from this distribution.
    /// </summary>
    /// <param name="random">The random number generator to draw from.</param>
    /// <param name="sample">The vector that receives the variate; it is overwritten.</param>
    protected override void FillRandomVariateCore(System.Random random, Vector sample)
    {
        if (random == null)
        {
            ThrowException.ArgumentNull(nameof(random));
        }
        if (sample == null)
        {
            ThrowException.ArgumentNull(nameof(sample));
        }
        this.DoPrecomputations();

        // z ~ N(0, I), then F * z ~ N(0, covariance).
        // Assumes MatrixOperationSide.Left computes factor * sample in place.
        NormalDistribution.Standard.GetRandomVariates(random, sample);
        sample.Multiply(this.factor, MatrixOperationSide.Left);

        // Shift by the mean so the result is N(mean, covariance) rather than N(0, covariance).
        for (int i = 0; i < sample.Length; i++)
        {
            sample[i] += this.mean[i];
        }
    }

    /// <summary>
    /// Creates a standard multivariate normal distribution with zero mean and identity covariance.
    /// </summary>
    /// <param name="order">The dimension of the distribution.</param>
    /// <returns>A new standard multivariate normal distribution.</returns>
    public static MultivariateNormalDistribution CreateStandard(int order)
    {
        return new MultivariateNormalDistribution(order);
    }

    /// <summary>
    /// Initializes a distribution of the given dimension with zero mean and identity covariance.
    /// </summary>
    /// <param name="order">The dimension of the distribution.</param>
    public MultivariateNormalDistribution(int order)
        : this(new ConstantVector(order, 0.0))
    {
    }

    /// <summary>
    /// Initializes a distribution with the given mean and identity covariance.
    /// </summary>
    /// <param name="mean">The mean vector; must be non-null and non-empty.</param>
    public MultivariateNormalDistribution(Vector mean)
        : base(Vector.SafeLength(mean))
    {
        if (mean == null)
        {
            ThrowException.ArgumentNull(nameof(mean));
        }
        if (mean.Length == 0)
        {
            ThrowException.ArgumentOutOfRange(nameof(mean));
        }
        this.mean = mean;
        this.covarianceMatrix = CreateIdentityCovariance(mean.Length);
    }

    /// <summary>
    /// Initializes a distribution with the given mean and variance-covariance matrix.
    /// </summary>
    /// <param name="mean">The mean vector; must be non-null and non-empty.</param>
    /// <param name="covarianceMatrix">
    /// The variance-covariance matrix; its row count must equal the length of <paramref name="mean"/>.
    /// </param>
    /// <exception cref="DimensionMismatchException">
    /// The row count of <paramref name="covarianceMatrix"/> differs from the length of <paramref name="mean"/>.
    /// </exception>
    public MultivariateNormalDistribution(Vector mean, SymmetricMatrix covarianceMatrix)
        : base(Vector.SafeLength(mean))
    {
        if (mean == null)
        {
            ThrowException.ArgumentNull(nameof(mean));
        }
        if (covarianceMatrix == null)
        {
            ThrowException.ArgumentNull(nameof(covarianceMatrix));
        }
        if (mean.Length == 0)
        {
            ThrowException.ArgumentOutOfRange(nameof(mean));
        }
        if (mean.Length != covarianceMatrix.RowCount)
        {
            throw new DimensionMismatchException(DimensionType.Row, nameof(covarianceMatrix), DimensionType.Length, nameof(mean));
        }
        this.mean = mean;
        this.covarianceMatrix = covarianceMatrix;
    }

    /// <summary>
    /// Initializes a distribution whose parameters are estimated from a data matrix in which
    /// each column is one variable.
    /// </summary>
    /// <param name="data">The observations; the dimension equals the column count.</param>
    /// <remarks>
    /// With at most one row the sample covariance is not defined, so the identity matrix is used.
    /// </remarks>
    public MultivariateNormalDistribution(Matrix data)
        : base(data != null ? data.ColumnCount : 0)
    {
        if (data == null)
        {
            ThrowException.ArgumentNull(nameof(data));
        }
        this.mean = data.ApplyToColumns(Stats.Mean);
        this.covarianceMatrix = data.RowCount <= 1
            ? CreateIdentityCovariance(data.ColumnCount)
            : Stats.CovarianceMatrix(data);
    }

    /// <summary>
    /// Initializes a distribution whose parameters are estimated from a set of variables.
    /// </summary>
    /// <param name="variables">The variables; must be non-null, non-empty, and contain no null elements.</param>
    /// <exception cref="ArgumentException"><paramref name="variables"/> contains a null element.</exception>
    public MultivariateNormalDistribution(NumericalVariable[] variables)
        : base(variables != null ? variables.Length : 0)
    {
        if (variables == null)
        {
            ThrowException.ArgumentNull(nameof(variables));
        }
        for (int i = 0; i < variables.Length; i++)
        {
            if (variables[i] == null)
            {
                throw new ArgumentException("The variables array must not contain null elements.", nameof(variables));
            }
        }
        if (variables.Length == 0)
        {
            ThrowException.ArgumentOutOfRange(nameof(variables));
        }
        this.mean = new GeneralVector(Array.ConvertAll(variables, Stats.Mean), reuseComponentArray: true);
        this.covarianceMatrix = Stats.CovarianceMatrix(variables);
    }

    /// <summary>
    /// Returns the natural logarithm of the probability density at <paramref name="x"/>.
    /// </summary>
    /// <param name="x">The point at which to evaluate the log density.</param>
    /// <returns>
    /// -(n/2) ln(2π) - ln|det F| - ½‖F⁻¹(x - mean)‖², where F is the cached covariance factor.
    /// </returns>
    public override double LogProbabilityDensityFunction(Vector x)
    {
        if (x == null)
        {
            ThrowException.ArgumentNull(nameof(x));
        }
        this.DoPrecomputations();

        // (x - mean) is a fresh temporary, so it can be overwritten with F⁻¹(x - mean) by the solve.
        GeneralVector whitened = (x - this.mean).AsGeneralVector();
        this.factor.Solve(whitened, overwrite: true);
        return this.scale - 0.5 * whitened.NormSquared();
    }

    /// <summary>Returns the mean vector (the stored instance, not a copy).</summary>
    public override Vector GetMeans()
    {
        return this.mean;
    }

    /// <summary>Returns the variance-covariance matrix (the stored instance, not a copy).</summary>
    public override SymmetricMatrix GetVarianceCovarianceMatrix()
    {
        return this.covarianceMatrix;
    }

    /// <summary>Creates an <paramref name="order"/> × <paramref name="order"/> identity covariance matrix.</summary>
    private static SymmetricMatrix CreateIdentityCovariance(int order)
    {
        SymmetricMatrix identity = Matrix.CreateSymmetric(order);
        identity.GetDiagonal().SetValue(1.0);
        return identity;
    }
}
// 如果对您有帮助,非常感谢您支持一下创造者的付出!
// 感谢支持技术分享,请扫码点赞支持:
// 技术合作交流qq:2401315930