在分类问题中,你要预测的变量y是离散的值,适合使用逻辑回归 (Logistic Regression) 算法。
通常从最简单的二元分类问题开始解决,即“是不是”的问题。
当 时,预测 y=1;
当 时,预测 y=0;
由此,MINST手写数字识别问题可以被分解为:
是数字0的概率是多少;
是数字1的概率是多少;
...
是数字9的概率是多少;
但是,线性回归模型,因为其预测的值可以超越[0,1]的范围,并不适合解决这样的问题。
因此引入一个新的模型,逻辑回归,该模型的输出变量范围始终在0和1之间。
逻辑回归模型的假设是:
其中,X代表特征向量,g代表逻辑函数。
常用的逻辑函数为S形函数(Sigmoid function),公式为:
再看代价函数:
其中
{
if y = 1 :
if y = 0 :
}
因此上式可以合并写成:
y = log(x):(以e为底)
仅看x∈(0, 1]区间部分:
因此 y = -log(x) 在x∈(0, 1]区间时:
因此 y = -log(1-x) 在x∈[0, 1)区间时:
为什么 在逻辑回归中不能再使用
?
因为在线性回归中, ,则对应的代价函数是凸函数,但是在逻辑回归中,
,则对应的代价函数是非凸函数。
如果代价函数是
非凸
如果代价函数是
凸
逻辑回归 Java示例
演示了通过逻辑回归,判断数组是否呈现凸性(即中间数字大于两边数字)。
/** 数据集 */
public class TrainingSet {
public List<Data> dataList = new ArrayList<>();
public void add(double y, double... x) {
Data data = new Data();
data.x = x;
data.y = y;
dataList.add(data);
}
public static class Data {
public double[] x;
public double y;
}
}
/** 逻辑回归例子 */
public class LogisticRegression {
public static void main(String[] args) {
LogisticRegression lr = new LogisticRegression();
lr.study();
lr.test();
}
/** 产生模拟数据
* 该模拟数据符合如下凸特性:即中间数字大于两边数字
* 例如:x1=0.3, x2=0.4, x3=0.1 为凸,因此 y = 1
* 例如:x1=0.5, x2=0.4, x3=0.1 为非凸,因此 y = 0
* */
TrainingSet mock() {
TrainingSet ts = new TrainingSet();
final int size = 200;
//生成凸数据
for(int i = 0; i < size; i++) {
double middle = Math.random();
double x1 = middle * Math.random();
double x3 = middle * Math.random();
ts.add(1, x1, middle, x3);
}
//生成非凸数据
for(int i = 0; i < size; i++) {
double x1 = Math.random();
double x3 = Math.random();
double min = Math.min(x1, x3);
double middle = min * Math.random();
ts.add(0, x1, middle, x3);
}
return ts;
}
double[] θ = new double[]{0,0,0};
final double alpha = 0.001;
/** 学习 */
void study() {
TrainingSet ts = mock();
double lastCost = Double.MAX_VALUE;
int limitTry = 1000;
while(--limitTry > 0) {
double cost = calculateCost(ts);
if(cost > lastCost) {
break;
}
lastCost = cost;
double[] dθs = new double[θ.length];//要同步更新
for(int i = 0; i < θ.length; i++) {
dθs[i] = calculateDelta(ts, i);
}
for(int i = 0; i < θ.length; i++) {
θ[i] -= alpha * dθs[i];
}
}
//System.out.println("limitTry剩余" + limitTry);
for(int i = 0; i < θ.length; i++) {
System.out.print("θ"+i+"=" + θ[i] + " ");
}
System.out.println();
}
/** 验证测试 */
void test() {
TrainingSet ts = mock();
int testNum = ts.dataList.size();
int correctNum = 0;
int errorNum = 0;
for(TrainingSet.Data data : ts.dataList) {
//System.out.println("预测值:" + hypothesis(data) + ",实际值:" + data.y);
double hy = hypothesis(data) >= 0.5 ? 1 : 0;
if(hy == data.y) {
correctNum++;
} else {
errorNum++;
}
}
System.out.println("测试数量=" + testNum + ", 正确=" + correctNum + ", 错误=" + errorNum + ", 准确率=" + 100*((double)correctNum/testNum) + "%'");
}
/** 计算代价 */
double calculateCost(TrainingSet ts) {
double sum = 0;
for(TrainingSet.Data data : ts.dataList) {
// if(data.y == 1) {
// sum += -Math.log(hypothesis(data));
// } else if(data.y == 0) {
// sum += -Math.log(1-hypothesis(data));
// }
sum += (-1 * data.y * (Math.log(hypothesis(data)))) - (1-data.y) * Math.log(1-hypothesis(data));//等同于上面注释部分代码
}
sum = sum / ts.dataList.size();
return sum;
}
/** 预测函数 */
double hypothesis(TrainingSet.Data data) {
double θTX = 0;
for(int i = 0; i < data.x.length; i++) {
θTX += θ[i] * data.x[i];
}
double val = 1 / (1 + Math.pow(Math.E, ( -1 * θTX )));
return val;
}
/** 计算代价函数的偏导数
* @param i 对θi求偏导 */
double calculateDelta(TrainingSet ts, int i) {
double sum = 0;
for(TrainingSet.Data data : ts.dataList) {
sum += (hypothesis(data) - data.y) * data.x[i];
}
sum = sum / ts.dataList.size();
return sum;
}
}