数据
x=(1,2,3,4,5)
y = (1,1.5,3,4.5,5)
算法结果
R语言运行结果
算法原理
x的均值:
xp=(x1+x2+x3+…+xn)/n
y的均值 :
yp=(y1+y2+y3+…+yn)/n
x的离差平方和:
lxx=sum( (xi-xp) ^ 2 )
x与y的离差乘积之和:
lxy=sum( (xi-xp)*(yi-yp) )
拟合直线 y’=kx+b
k=lxy / lxx
b=yp-k*xp
代码实现
(数据容器选用集合,这样可以把数据当向量运算)
集合求和方法
/**
 * Adds up every element of a list as a {@code double}.
 *
 * @param c the values to total; may be {@code null}
 * @return the total of all elements (0 for an empty list), or
 *         {@code Double.NaN} when {@code c} is {@code null}
 */
public static double sum(List<Number> c) {
    // A missing list is reported as NaN rather than as an exception.
    if (c == null) {
        return Double.NaN;
    }
    double total = 0;
    for (Number value : c) {
        total += value.doubleValue();
    }
    return total;
}
由于计算协方差需要对集合内所有数作乘法,与求和一样,都是对每个元素依次进行连续计算,不如定义一个通用的连续计算方法。
集合连续计算方法
第一个参数为一个集合
第二个参数为函数接口,入参1为记录值,入参2为待计算值,出参为计算后的值,是下次迭代的入参1。
/**
 * Continuous computation: folds a list into one value, left to right.
 *
 * <p>The first element seeds the accumulator. Every later element is
 * combined with it through {@code fun}, whose first argument is the running
 * value and whose second argument is the next element (as a {@code double}).
 * The value returned by {@code fun} becomes the running value of the next step.
 *
 * @param c   the values to fold; elements are expected to be non-null
 * @param fun the combining step; not invoked for lists of fewer than two elements
 * @return the folded value, or {@code Double.NaN} when {@code c} is
 *         {@code null} or empty
 */
public static double conc(List<Number> c, BiFunction<Number, Number, Number> fun) {
    // An empty list has no seed value; report NaN exactly like a null list
    // instead of failing with a NullPointerException.
    if (c == null || c.isEmpty()) {
        return Double.NaN;
    }
    Iterator<Number> itr = c.iterator();
    // Seed from the iterator so the accumulator never doubles as a "first element" flag.
    Number acc = itr.next();
    while (itr.hasNext()) {
        acc = fun.apply(acc, itr.next().doubleValue());
    }
    return acc.doubleValue();
}
使用方法,(计算方法由函数接口决定):
可以求和:
conc(x,(r,n)->r.doubleValue()+n.doubleValue());
可以求积:
conc(x,(r,n)->r.doubleValue()*n.doubleValue());
对集合内每一个元素进行计算更新
计算 xi-xp
/**
 * Element-wise transformation: applies a function to every element of a list
 * and collects the results, in order, into a new list. The input list is not
 * modified.
 *
 * <p>Yields {@code null} when either the list or the function is {@code null}.
 */
public static BiFunction<List<Number>, Function<Number, Number>, List<Number>>
calc = (source, mapper) -> {
    if (source == null || mapper == null) {
        return null;
    }
    List<Number> mapped = new ArrayList<>();
    for (Number item : source) {
        mapped.add(mapper.apply(item));
    }
    return mapped;
};
多个集合运算生成新集合
主要计算协方差----两个集合相乘 (xi-xp)*(yi-yp)
/**
 * Combines several lists position by position into one new list.
 *
 * <p>For every index {@code h} of the first list, the {@code h}-th element of
 * each input list is gathered into a temporary row (in input order) and
 * handed to the function; its result becomes element {@code h} of the output.
 * The number of rows is taken from the first list, so every other list must
 * be at least that long.
 *
 * <p>Yields {@code null} when the outer list or the function is {@code null},
 * and an empty list when the outer list contains no lists.
 */
public static BiFunction<List<List<Number>>, Function<List<Number>, Number>, List<Number>>
call = (columns, fun) -> {
    if (columns == null || fun == null) {
        return null;
    }
    List<Number> result = new ArrayList<>();
    // Nothing to combine: avoid indexing into an empty outer list.
    if (columns.isEmpty()) {
        return result;
    }
    int height = columns.get(0).size();
    for (int h = 0; h < height; h++) {
        List<Number> row = new ArrayList<>(columns.size());
        for (List<Number> column : columns) {
            row.add(column.get(h));
        }
        result.add(fun.apply(row));
    }
    return result;
};
拟合算法
/**
 * Fits the straight line {@code y = k*x + b} to the given points by least squares.
 *
 * <p>With {@code xp}/{@code yp} the means of x and y, the slope is
 * {@code k = lxy / lxx}, where {@code lxx = sum((xi-xp)^2)} and
 * {@code lxy = sum((xi-xp)*(yi-yp))}; the intercept is {@code b = yp - k*xp}.
 * When {@code lxx} is zero (for example all x values equal) the division does
 * not produce a meaningful slope.
 *
 * @param x the x coordinates
 * @param y the y coordinates, one per x coordinate
 * @return a two-element array {@code {k, b}}
 * @throws IllegalArgumentException if both lists are non-null but differ in size
 */
public static double[] lineFit(List<Number> x, List<Number> y) {
    // With unequal lengths the y mean would cover points that the products
    // below never see, silently producing a wrong line.
    if (x != null && y != null && x.size() != y.size()) {
        throw new IllegalArgumentException(
                "x and y must have the same number of points: " + x.size() + " vs " + y.size());
    }
    double xp = mean(x);
    List<Number> xi_xp = calc.apply(x, e -> e.doubleValue() - xp);
    double lxx = sum(calc.apply(xi_xp, e -> e.doubleValue() * e.doubleValue()));
    double yp = mean(y);
    List<Number> yi_yp = calc.apply(y, e -> e.doubleValue() - yp);
    // Pair up the deviations position by position and multiply each pair.
    double lxy = sum(call.apply(Arrays.asList(xi_xp, yi_yp),
            e -> conc(e, (r, n) -> r.doubleValue() * n.doubleValue())));
    double k = lxy / lxx;
    double b = yp - k * xp;
    return new double[] {k, b};
}
完整代码
package utility;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import javafx.application.Application;
import javafx.scene.Parent;
import javafx.scene.Scene;
import javafx.scene.chart.LineChart;
import javafx.scene.chart.NumberAxis;
import javafx.scene.chart.XYChart.Data;
import javafx.scene.chart.XYChart.Series;
import javafx.stage.Stage;
/**
 * Least-squares straight-line fitting on lists of numbers, plus a small
 * JavaFX demo that plots the sample points together with the fitted line.
 */
public class LineFit extends Application {

    /**
     * Element-wise transformation: applies a function to every element of a
     * list and collects the results, in order, into a new list.
     *
     * <p>Yields {@code null} when either the list or the function is {@code null}.
     */
    public static BiFunction<List<Number>, Function<Number, Number>, List<Number>>
    calc = (source, mapper) -> {
        if (source == null || mapper == null) {
            return null;
        }
        List<Number> mapped = new ArrayList<>();
        for (Number item : source) {
            mapped.add(mapper.apply(item));
        }
        return mapped;
    };

    /**
     * Combines several lists position by position into one new list.
     *
     * <p>For every index {@code h} of the first list, the {@code h}-th element
     * of each input list is gathered into a temporary row and handed to the
     * function; its result becomes element {@code h} of the output. The number
     * of rows is taken from the first list, so every other list must be at
     * least that long.
     *
     * <p>Yields {@code null} when the outer list or the function is
     * {@code null}, and an empty list when the outer list contains no lists.
     */
    public static BiFunction<List<List<Number>>, Function<List<Number>, Number>, List<Number>>
    call = (columns, fun) -> {
        if (columns == null || fun == null) {
            return null;
        }
        List<Number> result = new ArrayList<>();
        // Nothing to combine: avoid indexing into an empty outer list.
        if (columns.isEmpty()) {
            return result;
        }
        int height = columns.get(0).size();
        for (int h = 0; h < height; h++) {
            List<Number> row = new ArrayList<>(columns.size());
            for (List<Number> column : columns) {
                row.add(column.get(h));
            }
            result.add(fun.apply(row));
        }
        return result;
    };

    /**
     * Adds up every element of a list as a {@code double}.
     *
     * @param c the values to total; may be {@code null}
     * @return the total (0 for an empty list), or {@code Double.NaN} when
     *         {@code c} is {@code null}
     */
    public static double sum(List<Number> c) {
        if (c == null) {
            return Double.NaN;
        }
        double total = 0;
        for (Number value : c) {
            total += value.doubleValue();
        }
        return total;
    }

    /**
     * Continuous computation: folds a list into one value, left to right.
     *
     * <p>The first element seeds the accumulator. Every later element is
     * combined with it through {@code fun}, whose first argument is the
     * running value and whose second argument is the next element (as a
     * {@code double}).
     *
     * @param c   the values to fold; elements are expected to be non-null
     * @param fun the combining step; not invoked for lists of fewer than two elements
     * @return the folded value, or {@code Double.NaN} when {@code c} is
     *         {@code null} or empty
     */
    public static double conc(List<Number> c, BiFunction<Number, Number, Number> fun) {
        // An empty list has no seed value; report NaN like a null list
        // instead of failing with a NullPointerException.
        if (c == null || c.isEmpty()) {
            return Double.NaN;
        }
        Iterator<Number> itr = c.iterator();
        Number acc = itr.next();
        while (itr.hasNext()) {
            acc = fun.apply(acc, itr.next().doubleValue());
        }
        return acc.doubleValue();
    }

    /**
     * Arithmetic mean of a list.
     *
     * @param c the values to average; may be {@code null}
     * @return the mean, 0 for an empty list, or {@code Double.NaN} when
     *         {@code c} is {@code null}
     */
    public static double mean(List<Number> c) {
        if (c == null) {
            return Double.NaN;
        }
        return c.size() > 0 ? sum(c) / c.size() : 0;
    }

    /**
     * Fits the straight line {@code y = k*x + b} to the given points by least squares.
     *
     * <p>With {@code xp}/{@code yp} the means of x and y, the slope is
     * {@code k = lxy / lxx}, where {@code lxx = sum((xi-xp)^2)} and
     * {@code lxy = sum((xi-xp)*(yi-yp))}; the intercept is {@code b = yp - k*xp}.
     * When {@code lxx} is zero (for example all x values equal) the division
     * does not produce a meaningful slope.
     *
     * @param x the x coordinates
     * @param y the y coordinates, one per x coordinate
     * @return a two-element array {@code {k, b}}
     * @throws IllegalArgumentException if both lists are non-null but differ in size
     */
    public static double[] lineFit(List<Number> x, List<Number> y) {
        // With unequal lengths the y mean would cover points that the
        // products below never see, silently producing a wrong line.
        if (x != null && y != null && x.size() != y.size()) {
            throw new IllegalArgumentException(
                    "x and y must have the same number of points: " + x.size() + " vs " + y.size());
        }
        double xp = mean(x);
        List<Number> xi_xp = calc.apply(x, e -> e.doubleValue() - xp);
        double lxx = sum(calc.apply(xi_xp, e -> e.doubleValue() * e.doubleValue()));
        double yp = mean(y);
        List<Number> yi_yp = calc.apply(y, e -> e.doubleValue() - yp);
        // Pair up the deviations position by position and multiply each pair.
        double lxy = sum(call.apply(Arrays.asList(xi_xp, yi_yp),
                e -> conc(e, (r, n) -> r.doubleValue() * n.doubleValue())));
        double k = lxy / lxx;
        double b = yp - k * xp;
        return new double[] {k, b};
    }

    /**
     * Starts the JavaFX demo.
     *
     * @param args command-line arguments, forwarded to the JavaFX launcher
     */
    public static void main(String[] args) {
        launch(args);
    }

    /**
     * Shows a window containing the chart for the built-in sample data.
     *
     * @param primaryStage the stage supplied by the JavaFX runtime
     */
    @Override
    public void start(Stage primaryStage) throws Exception {
        List<Number> x = Arrays.asList(1, 2, 3, 4, 5);
        List<Number> y = Arrays.asList(1, 1.5, 3, 4.5, 5);
        primaryStage.setScene(new Scene(plot(x, y)));
        primaryStage.show();
    }

    /**
     * Builds the chart series for the fitted line: one point
     * {@code (xi, k*xi + b)} per x value, named after the line's equation.
     *
     * @param x the x coordinates
     * @param y the y coordinates (used only to compute the fit)
     * @return the series describing the fitted line
     */
    public Series<Number, Number> test(List<Number> x, List<Number> y) {
        double[] kb = lineFit(x, y);
        double k = kb[0];
        double b = kb[1];
        Series<Number, Number> series = new Series<>();
        for (Number xi : x) {
            series.getData().add(new Data<Number, Number>(xi, k * xi.doubleValue() + b));
        }
        series.setName("拟合直线 " + "Y=" + k + "x+(" + String.format("%.2f", b) + ")");
        return series;
    }

    /**
     * Builds the chart series for the raw points, pairing x and y values in
     * order and stopping at the end of the shorter list.
     *
     * @param x the x coordinates
     * @param y the y coordinates
     * @return the series named {@code "data"}
     */
    public Series<Number, Number> data(List<Number> x, List<Number> y) {
        Series<Number, Number> series = new Series<>();
        Iterator<Number> xi = x.iterator();
        Iterator<Number> yi = y.iterator();
        while (xi.hasNext() && yi.hasNext()) {
            series.getData().add(new Data<Number, Number>(xi.next(), yi.next()));
        }
        series.setName("data");
        return series;
    }

    /**
     * Creates a line chart holding the raw data series and the fitted-line series.
     *
     * @param x the x coordinates
     * @param y the y coordinates
     * @return the chart node
     */
    public Parent plot(List<Number> x, List<Number> y) {
        NumberAxis xAxis = new NumberAxis();
        NumberAxis yAxis = new NumberAxis();
        LineChart<Number, Number> chart = new LineChart<>(xAxis, yAxis);
        chart.getData().add(data(x, y));
        chart.getData().add(test(x, y));
        return chart;
    }
}
R语言代码
# Sample points to fit.
x=c(1,2,3,4,5)
y = c(1,1.5,3,4.5,5)
data1=data.frame(x=x,y=y)
# Ordinary least-squares fit of y on x.
lm.data1<-lm(y~x,data=data1)
# For the formula y~x the first coefficient is the intercept, the second the slope.
b<-round(lm.data1$coefficients[1],3)
k<-round(lm.data1$coefficients[2],3)
# Scatter plot of the points with the fitted line drawn on top.
plot(data1$x,data1$y,xlab="x",ylab = "y",col="red",pch="*")
abline(lm.data1,col="blue")
# Label the line with its equation; the value in parentheses is the intercept b.
text(mean(data1$x),max(data1$y),paste("y = ",k,"x+(",b,")",sep = ""))