机器学习(2)-- 代价函数的Java代码实现

代价函数主要是通过计算豆豆认知的均方误差e与认知w的关系,从而通过数学方式找到误差最小的点(即w值)。从数学上推理来看(其实就是高中知识,自己可以推),损失函数e与豆豆认知关系是一个抛物线函数关系,而该抛物线的最低点就对应最合适的w,这样就是误差的。上代码
豆豆生成代码
public class Bean {

private double xs;  //豆豆大小
private double ys;  //豆豆毒性
static double w = 0.1; //初始认知

public Bean() {
}

public Bean(double xs, double ys) {
    this.xs = xs;
    this.ys = ys;
}

public void setXs(double xs) {
    this.xs = xs;
}

public void setYs(double ys) {
    this.ys = ys;
}

public double getXs() {
    return xs;
}

public double getYs() {
    return ys;
}


@Override
public String toString() {
    return "Bean{" +
            "xs=" + xs +
            ", ys=" + ys +
            ", w=" + w +
            '}';
}

}
豆豆服务类
public class BeanServe {

/**
 * 造豆豆
 */
public static ArrayList<Bean> creatBeans(int num) {

    ArrayList<Bean> list = new ArrayList<>();
    for (int i = 0; i < num; i++) {
        double xs = Math.random() * 10;
        double ys = xs + Math.random() * (1.0 / 10.0);
        Bean bean = new Bean(xs, ys);
        list.add(bean);
    }
    return list;
}

/**
 * 计算一个豆豆的误差
 */
public static double singleError(Bean bean) {

    double y_pre = Bean.w * bean.getXs();
    double e = bean.getYs() - y_pre;
    return Math.pow(e, 2);
}

/**
 * 一群豆豆的误差
 */
public static double arrError(ArrayList<Bean> list) {

    double d = 0;
    Iterator<Bean> it = list.iterator();
    while (it.hasNext()) {
        double mid = singleError(it.next());
        d += mid;
    }
    return d;
}

}
程序入口
public class Main {

public static void main(String[] args) {

    /*
    豆豆生成
     */
    int num = 100;
    ArrayList<Bean> beans = BeanServe.creatBeans(num);
    System.out.println(beans);

    /*
    整体样本的平均误差
     */
    double sumE = BeanServe.arrError(beans);
    System.out.println(sumE);
    System.out.println(sumE / num);

    /*
    代价函数最低时的w
     */
    double xy = 0;
    double xx = 0;
    Iterator<Bean> it = beans.iterator();
    while (it.hasNext()) {
        Bean bean = it.next();
        double d1 = bean.getXs() * bean.getYs();
        double d2 = Math.pow(bean.getXs(), 2);
        xy += d1;
        xx += d2;
    }

    double w = xy / xx;
    System.out.println(w);
    Bean.w = w;

    /*
    整体样本的平均误差
     */
    double sum = BeanServe.arrError(beans);
    System.out.println(sum);
    System.out.println(sum / num);
}

}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值