广义线性回归
算法介绍:
与线性回归假设输出服从高斯分布不同,广义线性模型(GLMs)指定线性模型的因变量 服从指数型分布。Spark的GeneralizedLinearRegression接口允许指定GLMs包括线性回归、泊松回归、逻辑回归等来处理多种预测问题。目前 spark.ml仅支持指数型分布家族中的一部分类型,如下:
家族 | 因变量类型 | 支持类型 |
高斯 | 连续型 | Identity*, Log, Inverse |
二项 | 二值型 | Logit*, Probit, CLogLog |
泊松 | 计数型 | Log*, Identity, Sqrt |
伽马 | 连续型 | Inverse*, Idenity, Log |
*注意目前Spark在 GeneralizedLinearRegression
仅支持最多4096个特征,如果特征超过4096个将会引发异常。对于线性回归和逻辑回归,如果模型特征数量会不断增长,则可通过 LinearRegression 和LogisticRegression来训练。
GLMs要求的指数型分布可以为正则或者自然形式。自然指数型分布为如下形式:
其中 是强度参数,
是分散度参数。在GLM中响应变量
服从自然指数族分布:
其中强度参数 与响应变量
的期望值联系如下:
其中 由所选择的分布形式所决定。GLMs同样允许指定连接函数,连接函数决定了响应变量期望值与线性预测器之间的关系:
通常,连接函数选择如 ,在强度参数与线性预测器之间产生一个简单的关系。这种情况下,连接函数也称为正则连接函数:
GLM通过最大化似然函数来求得回归系数:
其中强度参数和回归系数的联系如下:
Spark的GeneralizedLinearRegression接口提供汇总统计来诊断GLM模型的拟合程度,包括残差、p值、残差、Akaike信息准则及其它。
参数:
family:
类型:字符串型。
含义:模型中使用的误差分布类型。
featuresCol:
类型:字符串型。
含义:特征列名。
fitIntercept:
类型:布尔型。
含义:是否训练拦截对象。
labelCol:
类型:字符串型。
含义:标签列名。
link:
类型:字符串型。
含义:连接函数名,描述线性预测器和分布函数均值之间关系。
linkPredictiongCol:
类型:字符串型。
含义:连接函数(线性预测器列名)。
maxIter:
类型:整数型。
含义:最多迭代次数(>=0)。
predictionCol:
类型:字符串型。
含义:预测结果列名。
regParam:
类型:双精度型。
含义:正则化参数(>=0)。
solver:
类型:字符串型。
含义:优化的求解算法。
tol:
类型:双精度型。
含义:迭代算法的收敛性。
weightCol:
类型:字符串型。
含义:列权重。
调用:
Scala:
import org.apache.spark.ml.regression.GeneralizedLinearRegression
// Load training data
val dataset = spark.read.format("libsvm")
.load("data/mllib/sample_linear_regression_data.txt")
val glr = new GeneralizedLinearRegression()
.setFamily("gaussian")
.setLink("identity")
.setMaxIter(10)
.setRegParam(0.3)
// Fit the model
val model = glr.fit(dataset)
// Print the coefficients and intercept for generalized linear regression model
println(s"Coefficients: ${model.coefficients}")
println(s"Intercept: ${model.intercept}")
// Summarize the model over the training set and print out some metrics
val summary = model.summary
println(s"Coefficient Standard Errors: ${summary.coefficientStandardErrors.mkString(",")}")
println(s"T Values: ${summary.tValues.mkString(",")}")
println(s"P Values: ${summary.pValues.mkString(",")}")
println(s"Dispersion: ${summary.dispersion}")
println(s"Null Deviance: ${summary.nullDeviance}")
println(s"Residual Degree Of Freedom Null: ${summary.residualDegreeOfFreedomNull}")
println(s"Deviance: ${summary.deviance}")
println(s"Residual Degree Of Freedom: ${summary.residualDegreeOfFreedom}")
println(s"AIC: ${summary.aic}")
println("Deviance Residuals: ")
summary.residuals().show()
Java:
import java.util.Arrays;
import org.apache.spark.ml.regression.GeneralizedLinearRegression;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// Load training data
Dataset<Row> dataset = spark.read().format("libsvm")
.load("data/mllib/sample_linear_regression_data.txt");
GeneralizedLinearRegression glr = new GeneralizedLinearRegression()
.setFamily("gaussian")
.setLink("identity")
.setMaxIter(10)
.setRegParam(0.3);
// Fit the model
GeneralizedLinearRegressionModel model = glr.fit(dataset);
// Print the coefficients and intercept for generalized linear regression model
System.out.println("Coefficients: " + model.coefficients());
System.out.println("Intercept: " + model.intercept());
// Summarize the model over the training set and print out some metrics
GeneralizedLinearRegressionTrainingSummary summary = model.summary();
System.out.println("Coefficient Standard Errors: "
+ Arrays.toString(summary.coefficientStandardErrors()));
System.out.println("T Values: " + Arrays.toString(summary.tValues()));
System.out.println("P Values: " + Arrays.toString(summary.pValues()));
System.out.println("Dispersion: " + summary.dispersion());
System.out.println("Null Deviance: " + summary.nullDeviance());
System.out.println("Residual Degree Of Freedom Null: " + summary.residualDegreeOfFreedomNull());
System.out.println("Deviance: " + summary.deviance());
System.out.println("Residual Degree Of Freedom: " + summary.residualDegreeOfFreedom());
System.out.println("AIC: " + summary.aic());
System.out.println("Deviance Residuals: ");
summary.residuals().show();
Python:
from pyspark.ml.regression import GeneralizedLinearRegression
# Load training data
dataset = spark.read.format("libsvm")\
.load("data/mllib/sample_linear_regression_data.txt")
glr = GeneralizedLinearRegression(family="gaussian", link="identity", maxIter=10, regParam=0.3)
# Fit the model
model = glr.fit(dataset)
# Print the coefficients and intercept for generalized linear regression model
print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
# Summarize the model over the training set and print out some metrics
summary = model.summary
print("Coefficient Standard Errors: " + str(summary.coefficientStandardErrors))
print("T Values: " + str(summary.tValues))
print("P Values: " + str(summary.pValues))
print("Dispersion: " + str(summary.dispersion))
print("Null Deviance: " + str(summary.nullDeviance))
print("Residual Degree Of Freedom Null: " + str(summary.residualDegreeOfFreedomNull))
print("Deviance: " + str(summary.deviance))
print("Residual Degree Of Freedom: " + str(summary.residualDegreeOfFreedom))
print("AIC: " + str(summary.aic))
print("Deviance Residuals: ")
summary.residuals().show()