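The way the bean samples are generated has changed a little. The previous version had no intercept b, so the generated beans lay approximately on a line through the origin. Adding an intercept raises the dimension of the problem by one. The single slope we descended on before now becomes a true "gradient": a vector of partial derivatives with one component per parameter. Because of the intercept, the bean samples also differ from the previous article and are more general. Previously gradient descent was applied only to w. This time it is applied to b as well, because the fit is best only when w and b take their optimal values together. That is where the loss function reaches its minimum, the bottom of the valley.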
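The update rules in the gradient descent code match a per-sample squared-error loss. Take a single bean (x, y) with prediction wx + b. Expanding that loss and differentiating with respect to each parameter gives the two partial derivatives the code calls `dw` and `db`. This is a worked derivation consistent with the code, not a formula quoted from the original article:

```latex
\begin{aligned}
e &= \bigl(y - (wx + b)\bigr)^2 = x^2 w^2 + b^2 + 2xwb - 2xyw - 2yb + y^2 \\
\frac{\partial e}{\partial w} &= 2x^2 w + 2xb - 2xy \\
\frac{\partial e}{\partial b} &= 2b + 2xw - 2y \\
w &\leftarrow w - \alpha \frac{\partial e}{\partial w}, \qquad
b \leftarrow b - \alpha \frac{\partial e}{\partial b}
\end{aligned}
```

Both derivatives depend on both w and b. This is why the two parameters have to be descended together rather than one after the other.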
Bean class
```java
public class Bean {
    private double xs; // bean size
    private double ys; // bean toxicity

    public Bean() {
    }

    public Bean(double xs, double ys) {
        this.xs = xs;
        this.ys = ys;
    }

    public void setXs(double xs) {
        this.xs = xs;
    }

    public void setYs(double ys) {
        this.ys = ys;
    }

    public double getXs() {
        return xs;
    }

    public double getYs() {
        return ys;
    }

    @Override
    public String toString() {
        return "Bean{" +
                "xs=" + xs +
                ", ys=" + ys +
                '}';
    }
}
```
Bean service class
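```java
import java.util.ArrayList;

public class BeanServe {
    public static final double K = 1.0; // true slope
    public static final double B = 0.5; // true intercept

    /**
     * Generates num beans whose toxicity follows ys = K * xs + B plus a small random noise.
     */
    public static ArrayList<Bean> creatBeans(int num) {
        ArrayList<Bean> list = new ArrayList<>();
        for (int i = 0; i < num; i++) {
            double xs = Math.random(); // size in [0, 1)
            // noise term (0.5 - Math.random()) / 5 lies in (-0.1, 0.1]
            double ys = K * xs + B + (0.5 - Math.random()) / 5;
            Bean bean = new Bean(xs, ys);
            list.add(bean);
        }
        return list;
    }
}
```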
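`Math.random()` returns a value in [0, 1), so the noise term stays within (-0.1, 0.1]. Every bean therefore lies within 0.1 of the line y = 1.0·x + 0.5. Unseeded random numbers make each run different. The sketch below is a hypothetical variant of `creatBeans` that draws from a seeded `java.util.Random` so runs are repeatable. The class name `SeededBeans`, the `seed` parameter, and the `double[]` pair layout are assumptions for illustration, not part of the original code:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical, self-contained variant of BeanServe.creatBeans that uses a fixed seed. */
public class SeededBeans {
    static final double K = 1.0; // true slope, same as BeanServe.K
    static final double B = 0.5; // true intercept, same as BeanServe.B

    /** Returns num {size, toxicity} pairs; toxicity = K*size + B + noise in (-0.1, 0.1]. */
    public static List<double[]> createBeans(int num, long seed) {
        Random rnd = new Random(seed);
        List<double[]> beans = new ArrayList<>();
        for (int i = 0; i < num; i++) {
            double xs = rnd.nextDouble();                          // size in [0, 1)
            double ys = K * xs + B + (0.5 - rnd.nextDouble()) / 5; // toxicity with bounded noise
            beans.add(new double[] {xs, ys});
        }
        return beans;
    }

    public static void main(String[] args) {
        List<double[]> beans = createBeans(100, 42L);
        double maxDeviation = 0;
        for (double[] bean : beans) {
            // distance of each bean from the generating line
            maxDeviation = Math.max(maxDeviation, Math.abs(bean[1] - (K * bean[0] + B)));
        }
        System.out.println("beans: " + beans.size() + ", max |noise|: " + maxDeviation);
    }
}
```

The same seed always yields the same beans. That makes it easier to tell whether a change in the fitted w and b comes from the algorithm or from the data.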
Gradient descent
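```java
import java.util.ArrayList;

public class Main {
    public static void main(String[] args) {
        /*
         * Generate the beans
         */
        int num = 100;
        ArrayList<Bean> beans = BeanServe.creatBeans(num); // typed list so beans.get(j) returns a Bean
        System.out.println(beans);

        /*
         * Initialize the parameters
         */
        double w = 0.1;      // w is the slope
        double b = 0.1;      // b is the intercept
        double alpha = 0.01; // learning rate: controls the size of each descent step

        /*
         * Gradient descent: 500 passes over the data, updating w and b after every bean
         */
        for (int i = 0; i < 500; i++) {
            for (int j = 0; j < num; j++) {
                Bean bean = beans.get(j);
                double xs = bean.getXs();
                double ys = bean.getYs();
                // partial derivatives of the squared error (ys - (w * xs + b))^2
                double dw = 2 * xs * xs * w + (2 * xs * b - 2 * xs * ys);
                double db = 2 * b + (2 * xs * w - 2 * ys);
                // both gradients were computed from the old w and b, then both are updated
                w = w - alpha * dw;
                b = b - alpha * db;
            }
        }
        System.out.println("Final w after gradient descent: " + w);
        System.out.println("Final b after gradient descent: " + b);
    }
}
```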
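One way to judge whether 500 passes with `alpha = 0.01` have converged is to compare the result with the closed-form ordinary least-squares fit of the same sample. That fit is the exact minimizer of the summed squared error. The sketch below swaps in that closed-form formula purely as a check. It reuses the per-sample update from `Main`, but works on plain arrays filled from a seeded `java.util.Random`. The class name `GdCheck`, the seed, and the array layout are assumptions for illustration:

```java
import java.util.Random;

/** Hypothetical check: per-sample gradient descent vs. closed-form least squares. */
public class GdCheck {
    /** Per-sample (stochastic) gradient descent with the same update rule as Main. Returns {w, b}. */
    public static double[] sgd(double[] x, double[] y, double alpha, int epochs) {
        double w = 0.1, b = 0.1; // same starting point as Main
        for (int i = 0; i < epochs; i++) {
            for (int j = 0; j < x.length; j++) {
                double dw = 2 * x[j] * x[j] * w + 2 * x[j] * b - 2 * x[j] * y[j];
                double db = 2 * b + 2 * x[j] * w - 2 * y[j];
                w -= alpha * dw;
                b -= alpha * db;
            }
        }
        return new double[] {w, b};
    }

    /** Ordinary least squares: w = cov(x, y) / var(x), b = mean(y) - w * mean(x). */
    public static double[] leastSquares(double[] x, double[] y) {
        int n = x.length;
        double mx = 0, my = 0;
        for (int i = 0; i < n; i++) {
            mx += x[i];
            my += y[i];
        }
        mx /= n;
        my /= n;
        double sxy = 0, sxx = 0;
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - mx) * (y[i] - my);
            sxx += (x[i] - mx) * (x[i] - mx);
        }
        double w = sxy / sxx;
        return new double[] {w, my - w * mx};
    }

    public static void main(String[] args) {
        Random rnd = new Random(7L);
        int num = 100;
        double[] x = new double[num], y = new double[num];
        for (int i = 0; i < num; i++) {
            x[i] = rnd.nextDouble();
            y[i] = 1.0 * x[i] + 0.5 + (0.5 - rnd.nextDouble()) / 5; // same generator as BeanServe
        }
        double[] gd = sgd(x, y, 0.01, 500);
        double[] ols = leastSquares(x, y);
        System.out.printf("SGD: w=%.4f b=%.4f%n", gd[0], gd[1]);
        System.out.printf("OLS: w=%.4f b=%.4f%n", ols[0], ols[1]);
    }
}
```

The updates happen once per bean with a fixed `alpha`, so the descent result hovers near the least-squares optimum rather than landing exactly on it. Both pairs should still be close to each other and to the generating values K = 1.0 and B = 0.5.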