代价函数主要是通过计算豆豆认知的均方误差e与认知w的关系,从而通过数学方式找到误差最小的点(即w值)。从数学上推理来看(其实就是高中知识,自己可以推),损失函数e与豆豆认知关系是一个抛物线函数关系,而该抛物线的最低点就对应最合适的w,这样就是误差的。上代码
豆豆生成代码
public class Bean {
private double xs; //豆豆大小
private double ys; //豆豆毒性
static double w = 0.1; //初始认知
public Bean() {
}
public Bean(double xs, double ys) {
this.xs = xs;
this.ys = ys;
}
public void setXs(double xs) {
this.xs = xs;
}
public void setYs(double ys) {
this.ys = ys;
}
public double getXs() {
return xs;
}
public double getYs() {
return ys;
}
@Override
public String toString() {
return "Bean{" +
"xs=" + xs +
", ys=" + ys +
", w=" + w +
'}';
}
}
豆豆服务类
public class BeanServe {
/**
* 造豆豆
*/
public static ArrayList<Bean> creatBeans(int num) {
ArrayList<Bean> list = new ArrayList<>();
for (int i = 0; i < num; i++) {
double xs = Math.random() * 10;
double ys = xs + Math.random() * (1.0 / 10.0);
Bean bean = new Bean(xs, ys);
list.add(bean);
}
return list;
}
/**
* 计算一个豆豆的误差
*/
public static double singleError(Bean bean) {
double y_pre = Bean.w * bean.getXs();
double e = bean.getYs() - y_pre;
return Math.pow(e, 2);
}
/**
* 一群豆豆的误差
*/
public static double arrError(ArrayList<Bean> list) {
double d = 0;
Iterator<Bean> it = list.iterator();
while (it.hasNext()) {
double mid = singleError(it.next());
d += mid;
}
return d;
}
}
程序入口
public class Main {
public static void main(String[] args) {
/*
豆豆生成
*/
int num = 100;
ArrayList<Bean> beans = BeanServe.creatBeans(num);
System.out.println(beans);
/*
整体样本的平均误差
*/
double sumE = BeanServe.arrError(beans);
System.out.println(sumE);
System.out.println(sumE / num);
/*
代价函数最低时的w
*/
double xy = 0;
double xx = 0;
Iterator<Bean> it = beans.iterator();
while (it.hasNext()) {
Bean bean = it.next();
double d1 = bean.getXs() * bean.getYs();
double d2 = Math.pow(bean.getXs(), 2);
xy += d1;
xx += d2;
}
double w = xy / xx;
System.out.println(w);
Bean.w = w;
/*
整体样本的平均误差
*/
double sum = BeanServe.arrError(beans);
System.out.println(sum);
System.out.println(sum / num);
}
}