梯度是对于曲面来说的,而线的梯度可简单理解为某点处的斜率OK!开始可以简单使用斜率不恰当地来替代梯度吧,方便理解。所以我们暂且把“梯度下降”在这一节叫成“斜率下降”吧,别介意哦!斜率下降就是使得某点的斜率慢慢变化,当斜率为0的地方不就正是抛物线的最低点吗?当然我们在初始斜率变化时要考虑如下问题:
(1)由于变化不能太剧烈,因此我们要引入学习率alpha,这样才不会使得变化太剧烈,收敛慢,甚至产生震荡不收敛。而学习率其实就是个衰减系数,不想让变化太剧烈。
(2)想让在初始的时候降得快一点,后期慢一点,就如同很多元启发算法一样。
我们就试着用随机梯度下降吧,还有一个叫批量梯度下降。这两者不同之处就是最小处理单元不一样,随机梯度下降算法的最小单位是一个样本,而批量梯度下降算法是先把样本分成小组,以小组作为最小单位进行操作。
随机梯度下降代码如下
豆豆类
public class Bean {
private double xs; //豆豆大小
private double ys; //豆豆毒性
static double w = 0.1; //初始认知
public Bean() {
}
public Bean(double xs, double ys) {
this.xs = xs;
this.ys = ys;
}
public void setXs(double xs) {
this.xs = xs;
}
public void setYs(double ys) {
this.ys = ys;
}
public double getXs() {
return xs;
}
public double getYs() {
return ys;
}
@Override
public String toString() {
return "Bean{" +
"xs=" + xs +
", ys=" + ys +
", w=" + w +
'}';
}
}
豆豆服务类
public class BeanServe {
/**
* 造豆豆
*/
public static ArrayList<Bean> creatBeans(int num) {
ArrayList<Bean> list = new ArrayList<>();
for (int i = 0; i < num; i++) {
double xs = Math.random() * 10;
double ys = xs + Math.random() * (1.0 / 10.0);
Bean bean = new Bean(xs, ys);
list.add(bean);
}
return list;
}
/**
* 计算一个豆豆的误差
*/
public static double singleError(Bean bean) {
double y_pre = Bean.w * bean.getXs();
double e = bean.getYs() - y_pre;
return Math.pow(e, 2);
}
/**
* 一群豆豆的误差
*/
public static double arrError(ArrayList<Bean> list) {
double d = 0;
Iterator<Bean> it = list.iterator();
while (it.hasNext()) {
double mid = singleError(it.next());
d += mid;
}
return d;
}
}
梯度下降代码
public class Main {
public static void main(String[] args) {
/*
豆豆生成
*/
int num = 100;
ArrayList beans = BeanServe.creatBeans(num);
System.out.println(beans);
/*
初始化参数
*/
double w = 0.1;
double k; //k为斜率(梯度)
double alpha = 0.01; //控制梯度下降的速度
/*
梯度下降过程
*/
for (int i = 0; i < 10; i++) {
for (int j = 0; j < num; j++) {
Bean bean = beans.get(j);
double xs = bean.getXs();
double ys = bean.getYs();
k = 2 * xs * xs * w - 2 * xs * ys;
System.out.println("斜率为:" + k);
w = w - alpha * k;
System.out.println("w值为:" + w);
}
}
System.out.println("经过梯度下降的w最终结果是:" + w);
}
}