A simple Java implementation of PSO (object-oriented)

This is Main.java

import PSO.PSO;

public class Main {
    public static void main(String[] args) {
        PSO a = new PSO(-10, 10, -10, 10, 50, 500);
        a.printAnswer();
    }
}

This is eachPoint.java

package PSO;

/**
 * eachPoint: one point (particle) of the swarm
 */
public class eachPoint {
    public Point pos;//now position
    public Point vec;//now speed vector
    public Point pBest = new Point();//(x,y) which makes f(x,y) the min
    public double preBestVal;//min value of f(x,y) which has been previously calculated

    //placeholder point at the origin with zero speed
    eachPoint(){
        pos = new Point();
        vec = new Point();
    }
    //generate random vector inside [x_a, x_b] x [y_a, y_b]
    //(done in the constructor body: field initializers would run before the bounds are known)
    eachPoint(double x_a, double x_b, double y_a, double y_b){
        pos = new Point(Math.random()*(x_b - x_a) + x_a, Math.random()*(y_b - y_a) + y_a);
        vec = new Point(Math.random(), Math.random());
    }
}
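
A note on Java initialization order, which matters for any class that draws a starting position from bounds passed to its constructor: instance field initializers are evaluated before the constructor body runs. The sketch below illustrates this with a hypothetical class `InitOrderDemo` (its name and fields are made up for the demonstration and are not part of the PSO package).

```java
// Hypothetical demo class, not part of the PSO package.
class InitOrderDemo {
    static double lo, hi;        // shared bounds, 0.0 until a constructor sets them
    double width = hi - lo;      // field initializer: evaluated BEFORE the constructor body

    InitOrderDemo(double a, double b) {
        lo = a;                  // too late for this object's width
        hi = b;
    }

    public static void main(String[] args) {
        InitOrderDemo first = new InitOrderDemo(-10, 10);
        InitOrderDemo second = new InitOrderDemo(-10, 10);
        System.out.println(first.width);   // 0.0, the bounds were still unset
        System.out.println(second.width);  // 20.0, bounds left over from the first call
    }
}
```

Computing such values inside the constructor body, from the constructor's own arguments, avoids the dependence on construction order.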

This is Point.java

package PSO;

public class Point {
    public double x, y;
    Point(){
        x = y = 0;
    }
    Point(double x, double y){
        this.x = x;
        this.y = y;
    }
}
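
Because Point is a mutable class, assigning one Point variable to another shares the object rather than copying the coordinates. A minimal sketch of the difference, using a hypothetical nested class `P` as a stand-in for Point so that the snippet compiles on its own:

```java
// Hypothetical demo; P mirrors the shape of Point but is not the class above.
class AliasDemo {
    static class P {
        double x, y;
        P(double x, double y) { this.x = x; this.y = y; }
    }

    public static void main(String[] args) {
        P pos = new P(1.0, 2.0);
        P alias = pos;                  // same object as pos
        P copy = new P(pos.x, pos.y);   // independent snapshot of the coordinates
        pos.x += 5.0;                   // the point moves
        System.out.println(alias.x);    // 6.0, the alias moved with it
        System.out.println(copy.x);     // 1.0, the snapshot kept the old value
    }
}
```

A "best position so far" needs the snapshot behaviour, since the point itself keeps moving afterwards.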

This is PSO.java

package PSO;

public class PSO {
    private double x_left, x_right;//x range of the search interval
    private double y_left, y_right;//y range of the search interval
    private Point globalBest;//best position found by any point so far
    //set of points
    private eachPoint[] pointSets;
    private int steps;//number of iterations
    private int pointNumbers;//number of points
    public double c1=2.0, c2=2.0;//acceleration coefficients
    public double r1, r2;//randomly generated coefficients
    public double w=0.1;//inertia weight

    public PSO(double x_a, double x_b, double y_a, double y_b, int pointNumbers, int step){//set the interval, point count and steps; then start the search
        changeVariables(x_a, x_b, y_a, y_b, pointNumbers, step);

        pointSets = new eachPoint[pointNumbers];
        for(int i = 0; i < pointNumbers; i++){//one entry per point, not per step
            pointSets[i] = new eachPoint(x_left, x_right, y_left, y_right);
        }
        try {
            run();//start to run
        } catch (Exception e) {
            System.out.println("failed");
            System.out.println(e.getMessage());
        }
    }
    private double f(double x1, double x2){//function f
        return (x1 - 1.1)*(x1 - 1.1) + (x2 + 4.4)*(x2 + 4.4) + 1;
    }
    public void printAnswer() {
        System.out.printf("[*]The answer point (x, y) is (%f, %f)\n", globalBest.x, globalBest.y);
    }
    public void changeVariables(double x_a, double x_b, double y_a, double y_b, int pointNumbers, int step) {
        this.x_left = x_a;
        this.x_right = x_b;
        this.y_left = y_a;
        this.y_right = y_b;
        this.steps = step;
        this.pointNumbers = pointNumbers;
    }

    public void run(){//run the algorithm for {steps} iterations
        double globalMinVal = Double.POSITIVE_INFINITY;//no value found yet
        for(int i = 1; i <= steps; i++){//iteration times
            for(int j = 0; j < pointNumbers; j++){//run for each point
                eachPoint each = pointSets[j];
                double nowVal = f(each.pos.x, each.pos.y);
                if(i == 1 || nowVal < each.preBestVal){//the first pass always records a personal best
                    //copy the coordinates; storing each.pos itself would alias the moving position
                    each.pBest = new Point(each.pos.x, each.pos.y);
                    each.preBestVal = nowVal;
                }
                if(each.preBestVal < globalMinVal){//only replace the global best when it improves
                    globalMinVal = each.preBestVal;
                    globalBest = new Point(each.pBest.x, each.pBest.y);
                }
                //update pos and speed vector
                r1 = Math.random();
                r2 = Math.random();
                each.vec.x = w*each.vec.x + c1*r1*(each.pBest.x - each.pos.x) + c2*r2*(globalBest.x - each.pos.x);
                each.vec.y = w*each.vec.y + c1*r1*(each.pBest.y - each.pos.y) + c2*r2*(globalBest.y - each.pos.y);
                each.pos.x += each.vec.x;
                each.pos.y += each.vec.y;
                //make each.pos in the range of [left, right]
                if(each.pos.x > x_right)  each.pos.x = x_right;
                if(each.pos.x < x_left)  each.pos.x = x_left;
                if(each.pos.y > y_right)  each.pos.y = y_right;
                if(each.pos.y < y_left)  each.pos.y = y_left;
            }
        }
    }
}
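
To see what a single update does, the velocity and position rule used in run() can be pulled out for one coordinate and fed fixed numbers. This is only a sketch: the class `PsoStepDemo` and its `step` helper are hypothetical names, and r1 and r2 are passed in as fixed values instead of being drawn from Math.random(), so the arithmetic can be followed by hand.

```java
// Hypothetical helper, not part of the PSO package.
class PsoStepDemo {
    // One update of a single coordinate; returns {newVelocity, newPosition}.
    static double[] step(double pos, double vel, double pBest, double gBest,
                         double w, double c1, double c2, double r1, double r2,
                         double lo, double hi) {
        // inertia term + pull toward the personal best + pull toward the global best
        double v = w * vel + c1 * r1 * (pBest - pos) + c2 * r2 * (gBest - pos);
        // move, then clamp the position to [lo, hi]
        double p = Math.max(lo, Math.min(hi, pos + v));
        return new double[] { v, p };
    }

    public static void main(String[] args) {
        // 0.1*1 + 2*0.5*(2-0) + 2*0.25*(4-0) = 0.1 + 2 + 2, so v and x are both about 4.1
        double[] r = step(0.0, 1.0, 2.0, 4.0, 0.1, 2.0, 2.0, 0.5, 0.25, -10.0, 10.0);
        System.out.println("v = " + r[0] + ", x = " + r[1]);
    }
}
```

With w = 0.1 the previous velocity contributes little, so each move is dominated by the two attraction terms.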

The file layout is as follows:

PSO:

    eachPoint.java

    Point.java

    PSO.java

Main.java

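Finally, for f(x, y) = (x-1.5)^2 + y^2, the search did converge to the minimum point (1.5, 0). (The f in the PSO.java listing is a different test function, (x-1.1)^2 + (y+4.4)^2 + 1, whose minimum is at (1.1, -4.4).)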
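
One way to cross-check a reported minimum without relying on PSO is a brute-force scan over a grid. The sketch below does this for f(x, y) = (x-1.5)^2 + y^2 on [-10, 10] x [-10, 10]; the class `GridCheck`, the helper `gridMin` and the grid spacing 0.5 are assumptions chosen for the illustration.

```java
// Hypothetical cross-check, not part of the PSO package.
class GridCheck {
    static double f(double x, double y) {
        return (x - 1.5) * (x - 1.5) + y * y;
    }

    // Scans a uniform grid on [lo, hi] x [lo, hi] with spacing h.
    // Returns {bestX, bestY, bestValue}.
    static double[] gridMin(double lo, double hi, double h) {
        int n = (int) Math.round((hi - lo) / h);
        double[] best = { lo, lo, f(lo, lo) };
        for (int i = 0; i <= n; i++) {
            for (int j = 0; j <= n; j++) {
                double x = lo + i * h;
                double y = lo + j * h;
                double v = f(x, y);
                if (v < best[2]) {      // keep the smallest value seen so far
                    best[0] = x;
                    best[1] = y;
                    best[2] = v;
                }
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] b = gridMin(-10.0, 10.0, 0.5);
        System.out.println(b[0] + " " + b[1] + " " + b[2]);  // 1.5 0.0 0.0
    }
}
```

A grid scan only finds the exact minimizer when it lies on a grid node, as it does here; PSO has no such restriction, which is why it is the more useful tool once the function gets less regular.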

 

This is my first time writing object-oriented Java, and I have to say the amount of code really is large (compared with C++).

It is not well written, so please bear with me.
