Java实现基于Socket的负载均衡代理服务器(含六种负载均衡算法)

目录

前言

一、常见负载均衡算法

1.完全轮询算法

2.加权轮询算法

3.完全随机算法

4.加权随机算法

5.余数Hash算法

6.一致性Hash算法

二、代码实现

1.项目结构

2.代码实现

总结


前言

负载均衡建立在现有网络结构之上,它提供了一种廉价有效透明的方法扩展网络设备服务器的带宽、增加吞吐量、加强网络数据处理能力、提高网络的灵活性和可用性。

负载均衡(Load Balance)其意思就是分摊到多个操作单元上进行执行,例如Web服务器FTP服务器企业关键应用服务器和其它关键任务服务器等,从而共同完成工作任务。

解释来源:百度百科


一、常见负载均衡算法

1.完全轮询算法

完全轮询就如字面意思一样,每个节点轮流处理请求,相当于一个指针在节点数组中循环遍历

2.加权轮询算法

加权轮询在完全轮询的基础上为每个服务器增加了权重,权重高的会处理更多的请求,这解决了不同节点配置不同的情况引发的分配不均问题,相当于节点数组中每个节点有对应权重数量的个数,指针循环遍历,可以理解为在线性区间内增加节点的线段长

3.完全随机算法

请求随机打到一个节点上进行处理,相当于在节点数组的长度内产生一个随机数,随机数所对应索引的节点为被选中节点

4.加权随机算法

在完全随机算法的基础上为节点配置权重,与加权轮询的原理类似,可认为调整节点在线性区间内的线段长度以增加被选中的概率

5.余数Hash算法

以上算法都不能解决一个缓存和session问题,因为同一个客户端的不同请求几乎每次都是不同的服务来处理的,这样客户的缓存和session等就会丢失,需要重构,所以可以对客户端ip进行Hash处理,之后对服务器数量取余,这样可以客户端的请求是同一个服务器来处理

6.一致性Hash算法

余数Hash虽然解决了客户端的请求到不同服务器的问题,但是当某个或某些服务器下线或上线的情况,几乎会为大多数客户端重新分配服务器,这样导致系统造成不稳定性,一致性Hash定义一个一致性Hash环,为节点计算出Hash值之后放入环内,当客户端请求时,为客户端计算Hash值,在一致性Hash环中沿一个特定的方向寻找离它最近的节点,这样即使遇到服务器上线会下线也不会对大多数服务器产生影响,当然我们可以为节点创建虚拟节点来更均匀合理的将他们分配到Hash环中

二、代码实现

1.项目结构

├── pom.xml
└── src
    ├── main
    │   ├── java
    │   │   ├── Main.java   启动类,ServerSocket绑定监听端口接收请求
    │   │   └── system
    │   │       ├── common
    │   │       │   ├── ConnectUtil.java    连接工具类,可测试是否与指定IP端口连接成功
    │   │       │   └── GetHashCode.java   重计算Hash值工具类
    │   │       ├── configure
    │   │       │   └── Configuration.java   配置类,解析xml配置文件并封装为配置类
    │   │       ├── entity
    │   │       │   └── Server.java   服务器类,包含serverName,ip,port,wight属性
    │   │       ├── random
    │   │       │   ├── BalanceService.java   负载均衡接口,包含获取server,增加server,删除server方法
    │   │       │   └── imp
    │   │       │       ├── ConsistentHashServerImpl.java  一致性Hash负载均衡实现类
    │   │       │       ├── HashServerImpl.java   余数Hash负载均衡实现类
    │   │       │       ├── PollServerImpl.java  完全轮询负载均衡实现类
    │   │       │       ├── RandomServerImpl.java   完全随机负载均衡实现类
    │   │       │       ├── ServerMonitorImpl.java   服务监视器,装饰者模式,为其他实现类增加服务监控、动态增减服务器功能
    │   │       │       ├── WeightPollServerImpl.java   加权轮询负载均衡实现类
    │   │       │       └── WeightRandomServerImpl.java   加权随机负载均衡实现类
    │   │       └── socket
    │   │           └── SocketThread.java   客户端Socket请求线程,每个客户端请求对应一个线程对象,提交到线程池
    │   └── resources
    │       ├── log4j.properties   日志配置文件
    │       └── xw-load-balancing.xml   项目配置文件
    └── test
        └── java

2.代码实现

xml配置文件示例

<?xml version="1.0" encoding="UTF-8"?>
<configuration>
    <servers>
        <server name="sever1" address="127.0.0.1" port="8083" weight="1"/>
        <server name="sever2" address="127.0.0.1" port="8082" weight="2"/>
        <server name="sever3" address="127.0.0.1" port="8081" weight="2"/>
        <server name="sever4" address="127.0.0.1" port="8080" weight="1"/>
    </servers>
    <settings>
        <!--虚拟节点数量-->
        <setting name="vnnNodeCount" value="3"/>
        <!--六种负载均衡方式可选,默认RandomServer-->
        <!--RandomServer-完全随机算法-->
        <!--WeightRandomServer-加权随机算法-->
        <!--PollServer-完全轮询算法-->
        <!--WeightPollServer-加权轮询算法-->
        <!--HashServer-余数Hash-->
        <!--ConsistentHash-一致性Hash-->
        <setting name="random" value="ConsistentHash"/>
        <!--是否打开服务监视器实现服务动态增减-->
        <setting name="openServerMonitor" value="true"/>
        <!--监听端口,默认8088-->
        <setting name="port" value="8088"/>
    </settings>
</configuration>

 Configuration.java

利用Dom4j解析xml配置文件,将配置里面的内容解析,将Server节点封装成Server对象list存储,根据选择的负载均衡算法来注册对应的实现类,如果开启服务监控则使用装饰器加强

package system.configure;

import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import system.entity.Server;
import system.random.BalanceService;
import system.random.imp.*;

import java.io.File;
import java.util.ArrayList;
import java.util.List;

/**
 * 配置类
 *
 * @author xuwei
 * @date 2022/07/18 11:30
 **/
public class Configuration {
    private volatile static Configuration configuration;
    private BalanceService balanceService;
    private Integer port;

    private Configuration(String fileName) {
        File file = new File(fileName);
        if (file.exists()) {
            SAXReader reader = new SAXReader();
            List<Server> serverList = new ArrayList<>();
            Document document = null;
            try {
                document = reader.read(file);
            } catch (DocumentException e) {
                e.printStackTrace();
            }
            Integer vnnNodeCount = 3;
            assert document != null;
            Element root = document.getRootElement();
            List<Element> childElements = root.elements();
            for (Element child : childElements) {
                if (!child.elements().isEmpty()) {
                    for (Element c : child.elements()) {
                        switch (child.getName()) {
                            case "servers":
                                serverList.add(new Server(c.attributeValue("name"), c.attributeValue("address"), Integer.valueOf("".equals(c.attributeValue("port")) ? "80" : c.attributeValue("port")), Integer.valueOf("".equals(c.attributeValue("weight")) ? "0" : c.attributeValue("weight"))));
                                break;
                            case "settings":
                                switch (c.attributeValue("name")) {
                                    case "port":
                                        this.port = Integer.valueOf(c.attributeValue("value") == null ? "8088" : "".equals(c.attributeValue("value")) ? "8088" : c.attributeValue("value"));
                                        break;
                                    case "vnnNodeCount":
                                        vnnNodeCount = c.attributeValue("value") == null ? vnnNodeCount : "".equals(c.attributeValue("value")) ? vnnNodeCount : Integer.valueOf(c.attributeValue("value"));
                                        break;
                                    case "random":
                                        String random = c.attributeValue("value") == null ? "RandomServer" : "".equals(c.attributeValue("value")) ? "RandomServer" : c.attributeValue("value");
                                        switch (random) {
                                            case "WeightRandomServer":
                                                balanceService = new WeightRandomServerImpl(serverList);
                                                break;
                                            case "PollServer":
                                                balanceService = new PollServerImpl(serverList);
                                                break;
                                            case "WeightPollServer":
                                                balanceService = new WeightPollServerImpl(serverList);
                                                break;
                                            case "HashServer":
                                                balanceService = new HashServerImpl(serverList);
                                                break;
                                            case "ConsistentHash":
                                                balanceService = new ConsistentHashServerImpl(serverList, vnnNodeCount);
                                                break;
                                            case "RandomServer":
                                            default:
                                                balanceService = new RandomServerImpl(serverList);
                                                break;
                                        }
                                    case "openServerMonitor":
                                        if ("true".equals(c.attributeValue("value"))) {
                                            balanceService = new ServerMonitorImpl(balanceService);
                                        }
                                        break;
                                    default:
                                        break;
                                }
                                break;
                            default:
                                break;
                        }
                    }
                }
            }
        }
    }

    public static Configuration getConfiguration(String fileName) {
        if (configuration == null) {
            synchronized (Configuration.class) {
                if (configuration == null) {
                    configuration = new Configuration(fileName);
                }
            }
        }
        return configuration;
    }

    public BalanceService getBalanceService() {
        return balanceService;
    }

    public Integer getPort() {
        return port;
    }

}

Server.java

包含serverName、address、port、weight属性,存储节点信息

package system.entity;

/**
 * @ClassName Server
 * @Author xuwei
 * @DATE 2022/4/11
 */
public class Server {
    private String serverName;
    private String address;
    private Integer port;
    private Integer weight;

    public Server() {
    }

    public Server(String serverName, String address, Integer port, Integer weight) {
        this.serverName = serverName;
        this.address = address;
        this.port = port;
        this.weight = weight;
    }

    public String getServerName() {
        return serverName;
    }

    public String getAddress() {
        return address;
    }

    public Integer getPort() {
        return port;
    }

    public Integer getWeight() {
        return weight;
    }

    @Override
    public String toString() {
        return "Server{" +
                "serverName='" + serverName + '\'' +
                ", address='" + address + '\'' +
                ", port=" + port +
                ", weight=" + weight +
                '}';
    }
}

SocketThread.java

Socket处理线程,根据节点信息创建远程连接,将客户端数据与远程服务器数据相互转发,实现代理

package system.socket;

import org.apache.log4j.Logger;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;

/**
 * @ClassName SocketThread
 * @Author xuwei
 * @DATE 2022/4/12
 */
public class SocketThread extends Thread {

    /**
     * 五分钟超时
     */
    public static final int SO_TIME_OUT = 300000;
    private static final int BUFFER_SIZE = 8092;
    private static final Logger log = Logger.getLogger(SocketThread.class);


    private final Socket localSocket;
    private final String remoteHost;
    private final Integer remotePort;
    private Socket remoteSocket;
    private InputStream remoteSocketInputStream;
    private OutputStream localSocketOutputStream;

    public SocketThread(Socket socket, String remoteHost, Integer remotePort) {
        this.localSocket = socket;
        this.remoteHost = remoteHost;
        this.remotePort = remotePort;
    }

    @Override
    public void run() {
        try {
            remoteSocket = new Socket();
            remoteSocket.connect(new InetSocketAddress(remoteHost, remotePort));
            //设置超时,超过时间未收到客户端请求,关闭资源
            //5分钟内无数据传输、关闭链接
            remoteSocket.setSoTimeout(SO_TIME_OUT);
            remoteSocketInputStream = remoteSocket.getInputStream();
            OutputStream remoteSocketOutputStream = remoteSocket.getOutputStream();
            InputStream localSocketInputStream = localSocket.getInputStream();
            localSocketOutputStream = localSocket.getOutputStream();
            new ReadThread().start();
            //写数据,负责读取客户端发送过来的数据,转发给远程
            dataTransmission(localSocketInputStream, remoteSocketOutputStream);
        } catch (Exception e) {
            log.warn(e);
        } finally {
            close();
        }
    }

    private void dataTransmission(InputStream inputStream, OutputStream outputStream) throws IOException {
        byte[] data = new byte[BUFFER_SIZE];
        int len;
        while ((len = inputStream.read(data)) > 0) {
            /*
              读到了缓存大小一致的数据,不需要拷贝,直接使用
              读到了比缓存大小的数据,需要拷贝到新数组然后再使用
             */
            if (len == BUFFER_SIZE) {
                outputStream.write(data);
            } else {
                byte[] dest = new byte[len];
                System.arraycopy(data, 0, dest, 0, len);
                outputStream.write(dest);
            }
        }
    }

    /**
     * 关闭资源
     */
    private void close() {
        try {
            if (remoteSocket != null && !remoteSocket.isClosed()) {
                remoteSocket.close();
                log.info("remoteSocket ---> " + remoteSocket.getRemoteSocketAddress().toString().replace("/", "") + " socket closed");
            }
        } catch (IOException e1) {
            e1.printStackTrace();
        }

        try {
            if (localSocket != null && !localSocket.isClosed()) {
                localSocket.close();
                log.info("localSocket ---> " + localSocket.getRemoteSocketAddress().toString().replace("/", "") + " socket closed");
            }
        } catch (IOException e1) {
            log.warn(e1);
        }
    }

    /**
     * 读数据线程负责读取远程数据后回写到客户端
     */
    class ReadThread extends Thread {
        @Override
        public void run() {
            try {
                dataTransmission(remoteSocketInputStream, localSocketOutputStream);
            } catch (IOException e) {
                log.warn(e);
            } finally {
                close();
            }
        }
    }


}



 ConnectUtil.java

测试服务是否可用

package system.common;

import org.apache.log4j.Logger;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;

/**
 * 连接测试工具类
 *
 * @author xuwei
 * @date 2022/07/18 15:45
 **/
public class ConnectUtil {
    private static final Logger logger = Logger.getLogger(ConnectUtil.class);

    /**
     * 测试telnet 机器端口的连通性
     *
     * @param hostname 地址
     * @param port     端口
     * @param timeout  超时时间
     * @return 是否连通
     */
    public static boolean telnet(String hostname, int port, int timeout) {
        Socket socket = new Socket();
        boolean isConnected = false;
        try {
            socket.connect(new InetSocketAddress(hostname, port), timeout);
            isConnected = socket.isConnected();
        } catch (IOException ignored) {
            logger.warn("Remote server \"" + hostname + ":" + port + "\" connect failed!");
        } finally {
            try {
                socket.close();
            } catch (IOException ignored) {
                isConnected = false;
                logger.warn("Remote server \"" + hostname + ":" + port + "\" connect failed!");
            }
        }
        return isConnected;
    }
}

GetHashCode.java

重新计算Hash值

package system.common;

/**
 * Hash重计算
 *
 * @author xuwei
 * @date 2022/07/18 11:30
 **/
public class GetHashCode {
    private static final long FNV_32_INIT = 2166136261L;
    private static final int FNV_32_PRIME = 16777619;

    public static int getHashCode(String origin) {

        int hash = (int) FNV_32_INIT;
        for (int i = 0; i < origin.length(); i++) {
            hash = (hash ^ origin.charAt(i)) * FNV_32_PRIME;
        }
        hash += hash << 13;
        hash ^= hash >> 7;
        hash += hash << 3;
        hash ^= hash >> 17;
        hash += hash << 5;
        hash = Math.abs(hash);
        return hash;
    }
}

 BalanceService.java

负载均衡接口,包含getServer、addSErver、delServer方法

package system.random;

import system.entity.Server;

/**
 * 负载均衡接口
 *
 * @author xuwei
 * @date 2022/07/18 10:41
 **/
public interface BalanceService {
    /**
     * 获取服务器
     *
     * @param requestNumber  请求量
     * @param requestAddress 请求地址
     * @return
     */
    Server getServer(int requestNumber, String requestAddress);

    /**
     * 添加服务器节点
     *
     * @param server server
     */
    void addServerNode(Server server);

    /**
     * 删除服务器节点
     *
     * @param server server
     */
    void delServerNode(Server server);
}

RandomServerImpl.java

完全随机负载均衡实现类,将Server放到list中模拟线性区间,生成伪随机数来模拟砸中某个线性区间的场景

package system.random.imp;

import org.apache.log4j.Logger;
import system.entity.Server;
import system.random.BalanceService;

import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * 完全随机实现类
 *
 * @author xuwei
 * @date 2022/07/18 10:41
 **/
public class RandomServerImpl implements BalanceService {

    private static final Logger logger = Logger.getLogger(RandomServerImpl.class);
    /**
     * 服务器列表
     */
    private final List<Server> serverList;
    /**
     * 伪随机数生成器
     */
    private final Random random = new Random();

    public RandomServerImpl(List<Server> serverList) {
        this.serverList = Collections.synchronizedList(serverList);
    }

    /**
     * 获取服务器
     *
     * @param requestNumber
     * @param requestAddress
     * @return
     */
    @Override
    public Server getServer(int requestNumber, String requestAddress) {
        Server server;
        if (serverList.isEmpty()) {
            logger.warn("Don not have server available!");
            return null;
        }
        server = serverList.get(random.nextInt(serverList.size()));
        return server;
    }

    /**
     * 添加服务器节点
     *
     * @param server server
     */
    @Override
    public void addServerNode(Server server) {
        serverList.add(server);
    }

    /**
     * 删除服务器节点
     *
     * @param server server
     */
    @Override
    public void delServerNode(Server server) {
        serverList.removeIf(server1 -> server1.getAddress().equals(server.getAddress()) && server1.getPort().equals(server.getPort()));
    }


}

WeightRandomServerImpl.java

加权随机负载均衡实现类,在list中放入对应权重数量的server来模拟线性区间内线段增长,以实现概率增加

package system.random.imp;

import org.apache.log4j.Logger;
import system.entity.Server;
import system.random.BalanceService;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * 加权随机实现类
 *
 * @author xuwei
 * @date 2022/07/18 10:41
 **/
public class WeightRandomServerImpl implements BalanceService {
    private static final Logger logger = Logger.getLogger(WeightRandomServerImpl.class);
    /**
     * 服务器列表
     */
    private final List<Server> serverList;
    /**
     * 伪随机数生成器
     */
    private final Random random = new Random();

    public WeightRandomServerImpl(List<Server> serverList) {
        List<Server> servers = new ArrayList<>();
        for (Server server : serverList) {
            for (int i = 0; i < server.getWeight(); i++) {
                servers.add(server);
            }
        }
        this.serverList = Collections.synchronizedList(servers);
    }

    /**
     * 获取服务器
     *
     * @param requestNumber
     * @param requestAddress
     * @return
     */
    @Override
    public Server getServer(int requestNumber, String requestAddress) {
        Server server;
        if (serverList.isEmpty()) {
            logger.warn("Don not have server available!");
            return null;
        }
        server = serverList.get(random.nextInt(serverList.size()));
        return server;
    }

    /**
     * 添加服务器节点
     *
     * @param server server
     */
    @Override
    public void addServerNode(Server server) {
        for (int i = 0; i < server.getWeight(); i++) {
            serverList.add(server);
        }
    }

    /**
     * 删除服务器节点
     *
     * @param server server
     */
    @Override
    public void delServerNode(Server server) {
        serverList.removeIf(server1 -> server1.getAddress().equals(server.getAddress()) && server1.getPort().equals(server.getPort()));
    }


}

PollServerImpl.java

完全轮询负载均衡实现类,指针循环遍历list

package system.random.imp;

import org.apache.log4j.Logger;
import system.entity.Server;
import system.random.BalanceService;

import java.util.Collections;
import java.util.List;

/**
 * 简单轮询实现类
 *
 * @author xuwei
 * @date 2022/07/18 10:41
 **/
public class PollServerImpl implements BalanceService {
    private static final Logger logger = Logger.getLogger(PollServerImpl.class);
    /**
     * 服务器列表
     */
    private final List<Server> serverList;

    public PollServerImpl(List<Server> serverList) {
        this.serverList = Collections.synchronizedList(serverList);
    }

    /**
     * 获取服务器
     *
     * @param requestNumber
     * @param requestAddress
     * @return
     */
    @Override
    public Server getServer(int requestNumber, String requestAddress) {
        Server server;
        if (serverList.isEmpty()) {
            logger.warn("Don not have server available!");
            return null;
        }
        server = serverList.get(requestNumber % serverList.size());
        return server;
    }

    /**
     * 添加服务器节点
     *
     * @param server server
     */
    @Override
    public void addServerNode(Server server) {
        serverList.add(server);
    }

    /**
     * 删除服务器节点
     *
     * @param server server
     */
    @Override
    public void delServerNode(Server server) {
        serverList.removeIf(server1 -> server1.getAddress().equals(server.getAddress()) && server1.getPort().equals(server.getPort()));
    }


}

WeightPollServerImpl.java

加权轮询负载均衡实现类,在完全轮询基础上将服务器在list中的数量增加为权重对应数量

package system.random.imp;

import org.apache.log4j.Logger;
import system.entity.Server;
import system.random.BalanceService;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * 加权轮询实现类
 *
 * @author xuwei
 * @date 2022/07/18 10:41
 **/
public class WeightPollServerImpl implements BalanceService {

    private static final Logger logger = Logger.getLogger(WeightPollServerImpl.class);
    /**
     * 服务器列表
     */
    private final List<Server> serverList;

    public WeightPollServerImpl(List<Server> serverList) {
        List<Server> servers = new ArrayList<>();
        for (Server server : serverList) {
            for (int i = 0; i < server.getWeight(); i++) {
                servers.add(server);
            }
        }
        this.serverList = Collections.synchronizedList(servers);
    }

    /**
     * 获取服务器
     *
     * @param requestNumber
     * @param requestAddress
     * @return
     */
    @Override
    public Server getServer(int requestNumber, String requestAddress) {
        Server server;
        if (serverList.isEmpty()) {
            logger.warn("Don not have server available!");
            return null;
        }
        server = serverList.get(requestNumber % serverList.size());
        return server;
    }

    /**
     * 添加服务器节点
     *
     * @param server server
     */
    @Override
    public void addServerNode(Server server) {
        for (int i = 0; i < server.getWeight(); i++) {
            serverList.add(server);
        }
    }

    /**
     * 删除服务器节点
     *
     * @param server server
     */
    @Override
    public void delServerNode(Server server) {
        serverList.removeIf(server1 -> server1.getAddress().equals(server.getAddress()) && server1.getPort().equals(server.getPort()));
    }


}

HashServerImpl.java

余数Hash负载均衡实现类,对客户端ip进行Hash运算,对服务节点数量取余来获取相应节点

package system.random.imp;

import org.apache.log4j.Logger;
import system.common.GetHashCode;
import system.entity.Server;
import system.random.BalanceService;

import java.util.Collections;
import java.util.List;

/**
 * 余数Hash实现类
 *
 * @author xuwei
 * @date 2022/07/18 10:41
 **/
public class HashServerImpl implements BalanceService {
    private static final Logger logger = Logger.getLogger(HashServerImpl.class);
    /**
     * 服务器列表
     */
    private final List<Server> serverList;

    public HashServerImpl(List<Server> serverList) {
        this.serverList = Collections.synchronizedList(serverList);
    }

    /**
     * 获取服务器
     * hash直接取余法
     *
     * @param requestNumber
     * @param requestAddress
     * @return
     */
    @Override
    public Server getServer(int requestNumber, String requestAddress) {
        Server server;
        if (serverList.isEmpty()) {
            logger.warn("Don not have server available!");
            return null;
        }
        server = serverList.get(GetHashCode.getHashCode(requestAddress) % serverList.size());
        return server;
    }

    /**
     * 添加服务器节点
     *
     * @param server server
     */
    @Override
    public void addServerNode(Server server) {
        serverList.add(server);
    }

    /**
     * 删除服务器节点
     *
     * @param server server
     */
    @Override
    public void delServerNode(Server server) {
        serverList.removeIf(server1 -> server1.getAddress().equals(server.getAddress()) && server1.getPort().equals(server.getPort()));
    }


}

ConsistentHashServerImpl.java

一致性Hash负载均衡实现类,可配置虚拟节点数量,使用TreeMap模拟一致性Hash环,客户端连接到达后计算ip的Hash之后去环内找它后一个节点,如果没有则找第一个节点

package system.random.imp;

import org.apache.log4j.Logger;
import system.common.GetHashCode;
import system.entity.Server;
import system.random.BalanceService;

import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * 一致性Hash实现类
 *
 * @author xuwei
 * @date 2022/07/18 10:41
 **/
public class ConsistentHashServerImpl implements BalanceService {
    private static final Logger logger = Logger.getLogger(ConsistentHashServerImpl.class);
    /**
     * 虚拟节点数
     */
    private final Integer vnnNodeCount;
    /**
     * 一致性hash环
     */
    private final TreeMap<Integer, Server> treeMapHash;


    public ConsistentHashServerImpl(List<Server> serverList, Integer vnnNodeCount) {
        this.vnnNodeCount = vnnNodeCount;
        TreeMap<Integer, Server> treeMapHash = new TreeMap<>();
        for (Server server : serverList) {
            int hash = GetHashCode.getHashCode(server.getAddress() + server.getPort());
            treeMapHash.put(hash, server);
            for (int i = 1; i <= this.vnnNodeCount; i++) {
                treeMapHash.put(GetHashCode.getHashCode(server.getAddress() + server.getPort() + "&&" + i), server);
            }
        }
        this.treeMapHash = treeMapHash;
    }

    /**
     * 获取服务器
     *
     * @param requestNumber  请求量
     * @param requestAddress 请求地址
     * @return
     */
    @Override
    public Server getServer(int requestNumber, String requestAddress) {
        Server server;
        synchronized (treeMapHash) {
            if (treeMapHash.isEmpty()) {
                logger.warn("Don not have server available!");
                return null;
            }
            int hash = GetHashCode.getHashCode(requestAddress);
            // 向右寻找第一个 key
            Map.Entry<Integer, Server> subEntry = treeMapHash.ceilingEntry(hash);
            // 设置成一个环,如果超过尾部,则取第一个点
            subEntry = subEntry == null ? treeMapHash.firstEntry() : subEntry;
            server = subEntry.getValue();
        }
        return server;
    }

    /**
     * 添加服务器节点
     *
     * @param server server
     */
    @Override
    public void addServerNode(Server server) {
        synchronized (treeMapHash) {
            int hash = GetHashCode.getHashCode(server.getAddress());
            treeMapHash.put(hash, server);
            for (int i = 1; i <= vnnNodeCount; i++) {
                int vnnNodeHash = GetHashCode.getHashCode(server.getAddress() + server.getPort() + "&&" + i);
                treeMapHash.put(vnnNodeHash, server);
            }
        }
    }

    /**
     * 删除服务器节点
     *
     * @param server server
     */
    @Override
    public void delServerNode(Server server) {
        synchronized (treeMapHash) {
            int hash = GetHashCode.getHashCode(server.getAddress() + server.getPort());
            treeMapHash.remove(hash);
            for (int i = 1; i <= vnnNodeCount; i++) {
                int vnnNodeHash = GetHashCode.getHashCode(server.getAddress() + server.getPort() + "&&" + i);
                treeMapHash.remove(vnnNodeHash);
            }
        }
    }


}

ServerMonitorImpl.java

服务监视器,每次获取server都会检测能否连通,连接失败则将节点移除并放入失败列表中,每三秒对列表中服务器重试,如果连接成功将节点放回并在失败列表中删除此服务

package system.random.imp;

import org.apache.log4j.Logger;
import system.common.ConnectUtil;
import system.entity.Server;
import system.random.BalanceService;

import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.locks.LockSupport;

/**
 * 装饰器,实现服务监控,动态增减服务
 *
 * @author xuwei
 * @date 2022/07/27 16:58
 **/
public class ServerMonitorImpl implements BalanceService {
    private static final Logger logger = Logger.getLogger(ServerMonitorImpl.class);
    private final BalanceService balanceService;
    /**
     * 连接失败服务器列表
     */
    private final List<Server> failServer = Collections.synchronizedList(new LinkedList<>());
    private final Thread serverMonitor;

    public ServerMonitorImpl(BalanceService balanceService) {
        this.balanceService = balanceService;
        Runnable runnable = () -> {
            logger.info("Server Monitor start!");
            while (true) {
                LockSupport.parkNanos(1000 * 1000 * 1000 * 3L);
                if (Thread.currentThread().isInterrupted()) {
                    logger.info("Server Monitor stop!");
                    return;
                }
                //对错误服务列表一直监控
                failServer.removeIf(server -> {
                    if (ConnectUtil.telnet(server.getAddress(), server.getPort(), 200)) {
                        addServerNode(server);
                        return true;
                    }
                    return false;
                });
            }
        };
        this.serverMonitor = new Thread(runnable);
        this.serverMonitor.setName("server-monitor");
        this.serverMonitor.start();
    }

    /**
     * 获取服务器
     *
     * @param requestNumber  请求量
     * @param requestAddress 请求地址
     * @return
     */
    @Override
    public Server getServer(int requestNumber, String requestAddress) {
        Server server;
        while (true) {
            Server server1 = balanceService.getServer(requestNumber, requestAddress);
            if (server1 == null) {
                this.serverMonitor.interrupt();
                return null;
            }
            // 测试连接
            boolean isConnected = ConnectUtil.telnet(server1.getAddress(), server1.getPort(), 200);
            if (isConnected) {
                server = server1;
                break;
            } else {
                //失败则加入到失效服务器列表并删除此节点
                failServer.add(server1);
                delServerNode(server1);
            }
        }
        return server;
    }

    /**
     * 添加服务器节点
     *
     * @param server server
     */
    @Override
    public void addServerNode(Server server) {
        balanceService.addServerNode(server);
    }

    /**
     * 删除服务器节点
     *
     * @param server server
     */
    @Override
    public void delServerNode(Server server) {
        balanceService.delServerNode(server);
    }
}

Main.java

ServerSocket监听端口,处理连接

import org.apache.log4j.Logger;
import org.springframework.scheduling.concurrent.CustomizableThreadFactory;
import system.configure.Configuration;
import system.entity.Server;
import system.random.BalanceService;
import system.socket.SocketThread;

import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * @ClassName Main
 * @Author xuwei
 * @DATE 2022/4/11
 */
public class Main {
    public static final int SO_TIME_OUT = 300000;
    private static final Configuration CONFIGURATION = Configuration.getConfiguration("src/main/resources/xw-load-balancing.xml");
    private static final ThreadPoolExecutor THREAD_POOL_EXECUTOR = new ThreadPoolExecutor(5, 10,
            60L, TimeUnit.SECONDS,
            new SynchronousQueue<>(), new CustomizableThreadFactory());
    private static final Logger logger = Logger.getLogger(Main.class);
    private static int requestNumber = 0;

    public static void main(String[] args) {
        BalanceService balanceService = CONFIGURATION.getBalanceService();
        try {
            //启动ServerSocket监听配置文件中的端口
            ServerSocket serverSocket = new ServerSocket(CONFIGURATION.getPort());
            logger.info("The service runs successfully on port " + CONFIGURATION.getPort());
            // 一直监听,接收到新连接,则开启新线程去处理
            while (true) {
                Socket localSocket = serverSocket.accept();
                //判断请求次数是否将要溢出
                requestNumber = requestNumber == Integer.MAX_VALUE ? 0 : ++requestNumber;
                //根据负载均衡算法获取转发服务器
                Server server = balanceService.getServer(requestNumber, localSocket.getInetAddress().getHostAddress());
                if (server == null) {
                    System.exit(0);
                }
                //5分钟内无数据传输、关闭链接
                localSocket.setSoTimeout(SO_TIME_OUT);
                logger.info(localSocket.getRemoteSocketAddress().toString().replace("/", "") + "  connect to server : \"" + server.getServerName() + "\"");
                //启动线程处理本连接
                THREAD_POOL_EXECUTOR.submit(new SocketThread(localSocket, server.getAddress(), server.getPort()));
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

总结

此项目只是一个简单的实践,目前IO模型还需优化

源码地址

https://github.com/CodeXu-cyber/xw-load-balancing.git

  • 7
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CodeXu_cyber

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值