ServerSocketChannel实现多Selector高并发server

本文中的主要代码转载自http://www.cnblogs.com/yueweimian/p/6262211.html,涉及到日志处理相关代码稍有修改和包名有修改,其他代码全部出自转载自上述链接的作者。
原作者的博文开头如下:

参考hbase RpcServer,编写了一个简洁版多Selector
server,对nio怎么用,Selector如何选择事件会有更深入的认识。

client端发送消息:内容长度 + 内容,200线程同时发送

server端接收消息:解析内容长度和内容,返回2MB测试数据给客户端

由于原作者并未在代码中添加详细的注释,但是个人认为原作者编写的这个高并发server,对深入理解非阻塞的ServerSocketChannel的运用还是很用帮助的。故本人在原作者的代码中添加了中文注释,希望把这个好东西一起分享给大家。

非阻塞通信

我们知道ServerSocket和socket,在运行中,常常会阻塞,比如serverSocket的accept()方法,如果没有客户端连接的话就会一直阻塞,还有一些Io阻塞,例如read()方法,如果没有读到足够的字节数,就会一直阻塞。如果一个服务器需要同时处理多个客户端数据的话,就需要开辟多个线程去处理。系统的开销很大。
jdk自1.4开始就提供了java.nio包,提供了很多用于非阻塞通讯的类。非阻塞通讯是基于,通道(channel)和缓冲器(byteBuffer)来协同工作的。
非阻塞模式下的ServerSocketChannel对于读写io操作操作都会快速返回,有多少就处理多少,结合selector(事件选择器),可以用单个线程处理多个客户端连接的事件。

服务器代码

package com.net.program.nio;



import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.util.*;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.LinkedBlockingQueue;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Created by wangkai8 on 17/1/5.
 */
public class Server {

    public static final Logger LOG = LogManager.getLogger();

    //读取来自客户端数据的阻塞队列
    private BlockingQueue<Call> queue = new LinkedBlockingQueue<Call>();

    //处理未写完的任务的阻塞队列
    private Queue<Call> responseCalls = new ConcurrentLinkedQueue<Call>();

    volatile boolean running = true;

    //处理向客户端写任务的对象
    private Responder responder = null;
    //缓冲器读写策略的阀值
    private static int NIO_BUFFER_LIMIT = 64 * 1024;

    private int handler = 10;


    /**
     * 负责服务器的初始化,实例化serverSocketChannel,和selector,包括serverSocketChannel绑定ip地址和端口,
     * 并向通道注册accept事件。启动多个reader线程,用于后续和客户端连接通讯
     * 
     *
     */
    class Listener extends Thread {

        Selector selector;
        Reader[] readers;
        int robin;
        int readNum;

        Listener(int port) throws IOException {
            ServerSocketChannel serverChannel = ServerSocketChannel.open();
            serverChannel.configureBlocking(false);
            serverChannel.socket().bind(new InetSocketAddress(port), 150);
            selector = Selector.open();
            serverChannel.register(selector, SelectionKey.OP_ACCEPT);
            readNum = 10;
            readers = new Reader[readNum];
            for(int i = 0; i < readNum; i++) {
                readers[i] = new Reader(i);
                readers[i].start();
            }
        }


        public void run() {
            while(running) {
                try {
                    selector.select();
                    Iterator<SelectionKey> it = selector.selectedKeys().iterator();
                    while(it.hasNext()) {
                        SelectionKey key = it.next();
                        it.remove();
                        if(key.isValid()) {
                            if(key.isAcceptable()) {
                                doAccept(key);
                            }
                        }
                    }
                } catch (IOException e) {
                    LOG.error("", e);
                }
            }
        }

        //处理客户端连接事件
        public void doAccept(SelectionKey selectionKey) throws IOException {
            ServerSocketChannel serverSocketChannel = (ServerSocketChannel) selectionKey.channel();
            SocketChannel socketChannel;
            while((socketChannel = serverSocketChannel.accept()) != null) {
                try {
                    //设置为非阻塞的模式
                    socketChannel.configureBlocking(false);
                    socketChannel.socket().setTcpNoDelay(true);
                    socketChannel.socket().setKeepAlive(true);
                } catch (IOException e) {
                    socketChannel.close();
                    throw e;
                }
                //取模的方式获取reader对象用于处理后续操作
                Reader reader = getReader();
                try {
                    //添加新的任务时,阻塞当前reader的任务  
                    //这添加同步代码块的原因应该是,为了防止connection对象,没构造完或者还没被添加到附件时,而读事件却触发了、
                    //所以添加同步的机制,当添加新的任务时,阻塞当前读取任务的执行。
                    reader.startAdd();
                    //注册可读事件
                    SelectionKey readKey = reader.registerChannel(socketChannel);
                    //构造connection对象
                    Connection c = new Connection(socketChannel);
                    //添加到该SelectionKey对象的附件中
                    readKey.attach(c);
                } finally {
                    //添加完成任务,唤醒被阻塞reader对象
                    reader.finishAdd();
                }
            }
        }

        public Reader getReader() {
            //防止超过int的最大值
            if(robin == Integer.MAX_VALUE) {
                robin = 0;
            }
            return readers[(robin ++) % readNum];
        }
    }

    //用来处理客户端的可读事件
    class Reader extends Thread {

        Selector readSelector;
        boolean adding;

        Reader(int i) throws IOException {
            setName("Reader-" + i);
            this.readSelector = Selector.open();
            LOG.info("Starting Reader-" + i + "...");
        }

        @Override
        public void run() {
            //循环处理读事件
            while(running) {
                try {
                    //
                    readSelector.select();

                    //这添加同步代码块的原因应该是,为了防止connection对象,没构造完,而读事件却触发了、
                    while(adding) {
                        synchronized(this) {
                            this.wait(1000);
                        }
                    }

                    Iterator<SelectionKey> it = readSelector.selectedKeys().iterator();
                    while(it.hasNext()) {
                        SelectionKey key = it.next();
                        it.remove();
                        if(key.isValid()) {
                            if(key.isReadable()) {
                                doRead(key);
                            }
                        }
                    }
                } catch (IOException e) {
                    LOG.error("", e);
                } catch (InterruptedException e) {
                    LOG.error("", e);
                }
            }
        }

        //具体的读操作方法
        public void doRead(SelectionKey selectionKey) {
            //获取关联的connection对象
            Connection c = (Connection) selectionKey.attachment();
            if(c == null) {
                return;
            }

            int n;
            try {
                n = c.readAndProcess();
            } catch (IOException e) {
                LOG.error("", e);
                n = -1;
            } catch (Exception e) {
                LOG.error("", e);
                n = -1;
            }
            if(n == -1) {
                c.close();
            }
        }

        //socketChannel向readSelector注册可读事件
        public SelectionKey registerChannel(SocketChannel channel) throws IOException {
            return channel.register(readSelector, SelectionKey.OP_READ);
        }

        //添加新的任务时,阻塞当前任务  
        public void startAdd() {
            adding = true;
            readSelector.wakeup();
        }

        public synchronized void finishAdd() {
            adding = false;
            this.notify();
        }
    }


    class Connection {
        private SocketChannel channel;
        //用来保存客户端内容的长度
        private ByteBuffer dataBufferLength;
        //用来读取客户端具体的内容
        private ByteBuffer dataBuffer;
        private boolean skipHeader;

        public Connection(SocketChannel channel) {
            this.channel = channel;
            //为dataBufferLength为4个字节的缓冲器
            this.dataBufferLength = ByteBuffer.allocate(4);
        }

        public int readAndProcess() throws IOException {
            int count;
            //第一次获取,要读的内容的长度,并将值存于dataBufferLength缓冲器中,int值,就4个字节,
            if(!skipHeader) {
                count = channelRead(channel, dataBufferLength);
                if (count < 0 || dataBufferLength.remaining() > 0) {
                    return count;
                }
            }

            skipHeader = true;
            //如果dataBuffer为null就为dataBuffer初始化,分配的字节长度为dataBufferLength读到的值
            if(dataBuffer == null) {
                dataBufferLength.flip();
                int dataLength = dataBufferLength.getInt();
                dataBuffer = ByteBuffer.allocate(dataLength);
            }

            count = channelRead(channel, dataBuffer);
            //如果读取的数据不为0,并且没有剩余的长度,说明读取完成,执行下一步操作
            if(count >= 0 && dataBuffer.remaining() == 0) {
                process();
            }

            return count;
        }


        /**
         * process the dataBuffer
         */
        public void process() {
            dataBuffer.flip();
            byte[] data = dataBuffer.array();
            //构造Call对象,并加入到阻塞队列
            Call call = new Call(this, data, responder);
            try {
                queue.put(call);
            } catch (InterruptedException e) {
                LOG.error("", e);
            }

        }


        public void close() {
            if(channel != null) {
                try {
                    channel.close();
                } catch (IOException e) {
                }
            }
        }
    }

    //向客户端回写数据
    class Responder extends Thread {

        Selector writeSelector;

        public Responder() throws IOException {
            writeSelector = Selector.open();
        }

        public void run() {
            //循环处理写事件
            while(running) {
                try {
                    registWriters();
                    int n = writeSelector.select(1000);
                    if(n == 0) {
                        continue;
                    }
                    Iterator<SelectionKey> it = writeSelector.selectedKeys().iterator();
                    while(it.hasNext()) {
                        SelectionKey key = it.next();
                        it.remove();
                        if(key.isValid() && key.isWritable()) {
                            doAsyncWrite(key);
                        }
                    }
                } catch (IOException e) {
                    LOG.error("", e);
                }
            }
        }

        //从阻塞队列中获取,call对象,并为该对象中的socketChannel注册写事件
        public void registWriters() throws IOException {
            Iterator<Call> it = responseCalls.iterator();
            while(it.hasNext()) {
                Call call = it.next();
                it.remove();
                SelectionKey key = call.conn.channel.keyFor(writeSelector);
                try {
                    if (key == null) {
                        try {
                            //注册写事件,并将call对象作为附件
                            call.conn.channel.register(writeSelector, SelectionKey.OP_WRITE, call);
                        } catch (ClosedChannelException e) {
                            //the client went away
                            if (LOG.isTraceEnabled())
                                LOG.trace("the client went away", e);
                        }
                    } else {
                        key.interestOps(SelectionKey.OP_WRITE);
                    }
                } catch (CancelledKeyException e) {
                    if (LOG.isTraceEnabled())
                        LOG.trace("the client went away", e);
                }
            }
        }

        //将Call对象添加至写的阻塞队列
        public void registerForWrite(Call call) throws IOException {
            responseCalls.add(call);
            writeSelector.wakeup();
        }

        private void doAsyncWrite(SelectionKey key) throws IOException {
            Call call = (Call) key.attachment();
            if(call.conn.channel != key.channel()) {
                throw new IOException("bad channel");
            }
            //将内容写回客户端
            int numBytes = channelWrite(call.conn.channel, call.response);
            if(numBytes < 0 || call.response.remaining() == 0) {
                try {
                    key.interestOps(0);
                } catch (CancelledKeyException e) {
                    LOG.warn("Exception while changing ops : " + e);
                }
            }
        }

        private void doResponse(Call call) throws IOException {
            //if data not fully send, then register the channel for async writer
            //如果数据没有完全发送完,就作为任务加入待发送的阻塞队列
            if(!processResponse(call)) {
                //没写完就加入写会的阻塞队列
                registerForWrite(call);
            }
        }

        //将call的buffer内容写回客户端,如果全部写完,则返回true,如果未全部写完,则返回false,之后未写完的,作为任务加入至阻塞队列。
        private boolean processResponse(Call call) throws IOException {
            boolean error = true;
            try {
                int numBytes = channelWrite(call.conn.channel, call.response);
                if (numBytes < 0) {
                    throw new IOException("error socket write");
                }
                error = false;
            } finally {
                if(error) {
                    call.conn.close();
                }
            }
            if(!call.response.hasRemaining()) {
                call.done = true;
                return true;
            }
            return false;
        }
    }
    //处理queue队列中的任务。
    class Handler extends Thread {

        public Handler(int i) {
            setName("handler-" + i);
            LOG.info("Starting Handler-" + i + "...");
        }

        public void run() {
            while(running) {
                try {
                    Call call = queue.take();
                    process(call);
                } catch (InterruptedException e) {
                    LOG.error("", e);
                } catch (IOException e) {
                    LOG.error("", e);
                }
            }
        }

        public void process(Call call) throws IOException {
            //打印从客户端收到的数据
            byte[] request = call.request;
            String message = new String(request);
            LOG.info("received mseesage: " + message);

            //each channel write 2MB data for test
            int dataLength = 2 * 1024 * 1024;
            ByteBuffer buffer = ByteBuffer.allocate(4 + dataLength);

            buffer.putInt(dataLength);
            writeDataForTest(buffer);
            buffer.flip();

            //将buffer赋给call对象的response,为写回做准备
            call.response = buffer;
            responder.doResponse(call);
        }
    }
    //为了测试填充数据,前4个字节为,回写客户端数据的字节长度
    public void writeDataForTest(ByteBuffer buffer) {
        int n = buffer.limit() - 4;
        for(int i = 0; i < n; i++) {
            buffer.put((byte)0);
        }
    }


    class Call {
        Connection conn;
        byte[] request;
        Responder responder;
        ByteBuffer response;
        boolean done;
        public Call(Connection conn, byte[] request, Responder responder) {
            this.conn = conn;
            this.request = request;
            this.responder = responder;
        }
    }

    //当前buffer剩余的长度少于64M,就将数据直接读入buffer,如果大于64M,则循环读入(估计是为了防止一次性读太多,引起系统的开销过大,所以循环读)
    public int channelRead(ReadableByteChannel channel, ByteBuffer buffer) throws IOException {
        return buffer.remaining() <= NIO_BUFFER_LIMIT ? channel.read(buffer) : channleIO(channel, null, buffer);
    }

    public int channelWrite(WritableByteChannel channel, ByteBuffer buffer) throws IOException {
        return buffer.remaining() <= NIO_BUFFER_LIMIT ? channel.write(buffer) : channleIO(null, channel, buffer);
    }


    public int channleIO(ReadableByteChannel readCh, WritableByteChannel writeCh, ByteBuffer buffer) throws IOException {
        int initRemaining = buffer.remaining();
        int originalLimit = buffer.limit();

        int ret = 0;
        try {
            while (buffer.remaining() > 0) {
                int ioSize = Math.min(buffer.remaining(), NIO_BUFFER_LIMIT);
                buffer.limit(buffer.position() + ioSize);
                ret = readCh == null ? writeCh.write(buffer) : readCh.read(buffer);
                if (ret < ioSize) {
                    break;
                }
            }
        } finally {
            buffer.limit(originalLimit);
        }

        int byteRead = initRemaining - buffer.remaining();
        return byteRead > 0 ? byteRead : ret;
    }


    public void startHandler() {
        for(int i = 0; i < handler; i++) {
            new Handler(i).start();
        }
    }


    public void start() throws IOException {
        new Listener(10000).start();
        responder = new Responder();
        responder.start();
        startHandler();
        LOG.info("server startup! ");
    }

    public static void main(String[] args) throws IOException {
        Server server = new Server();
        server.start();
    }
}
客户端代码
package com.net.program.nio;


import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import javax.net.SocketFactory;
import java.io.*;
import java.net.InetSocketAddress;
import java.net.Socket;


/**
 * Created by wangkai8 on 17/1/6.
 */
public class Client {

     public static final Logger LOG = LogManager.getLogger();

    Socket socket;
    OutputStream out;
    InputStream in;

    public Client() throws IOException {
        socket = SocketFactory.getDefault().createSocket();
        socket.setTcpNoDelay(true);
        socket.setKeepAlive(true);
        InetSocketAddress server = new InetSocketAddress("localhost", 10000);
        socket.connect(server, 10000);
        out = socket.getOutputStream();
        in = socket.getInputStream();
    }


    public void send(String message) throws IOException {
        byte[] data = message.getBytes();
        DataOutputStream dos = new DataOutputStream(out);
        //头4个字节作为本次内容的长度
        dos.writeInt(data.length);
        //下面为具体的发送内容
        dos.write(data);
        out.flush();
    }

    //启动200个线程向服务器进行通讯,客户端采用阻塞的socket
    public static void main(String[] args) throws IOException {
        int n = 200;
        for(int i = 0; i < n; i++) {
            new Thread() {
                Client client = new Client();

                public void run() {
                    try {
                        client.send(getName() + "_xiaomiemie");

                        DataInputStream inputStream = new DataInputStream(client.in);
                        int dataLength = inputStream.readInt();
                        byte[] data = new byte[dataLength];
                        inputStream.readFully(data);
                        client.socket.close();
                        LOG.info("receive from server: dataLength=" + data.length);
                    } catch (IOException e) {
                        LOG.error("", e);
                    } catch (Exception e) {
                        LOG.error("", e);
                    }
                }
            }.start();
        }
    }

}
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
要使用 `ServerSocketChannel` 结合 `Selector` 和 `ThreadPoolExecutor` 实现多线程处理事务,你可以按照以下步骤进行操作: 1. 创建一个 `ServerSocketChannel` 实例,并将其注册到 `Selector` 中以便监听连接请求。 ```java ServerSocketChannel serverSocketChannel = ServerSocketChannel.open(); serverSocketChannel.bind(new InetSocketAddress(port)); serverSocketChannel.configureBlocking(false); Selector selector = Selector.open(); serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT); ``` 2. 创建一个 `ThreadPoolExecutor` 实例,用于管理线程池。 ```java ThreadPoolExecutor executor = new ThreadPoolExecutor( corePoolSize, maxPoolSize, keepAliveTime, TimeUnit.SECONDS, new LinkedBlockingQueue<>(), new ThreadPoolExecutor.CallerRunsPolicy() ); ``` 3. 进入事件循环,使用 `Selector` 监听事件并处理连接请求。 ```java while (true) { // 等待事件发生 selector.select(); // 处理事件 Set<SelectionKey> selectedKeys = selector.selectedKeys(); Iterator<SelectionKey> iterator = selectedKeys.iterator(); while (iterator.hasNext()) { SelectionKey key = iterator.next(); if (key.isAcceptable()) { // 处理连接请求,创建新的 SocketChannel ServerSocketChannel serverChannel = (ServerSocketChannel) key.channel(); SocketChannel socketChannel = serverChannel.accept(); socketChannel.configureBlocking(false); // 将新的 SocketChannel 注册到 Selector 上,以便监听读取事件 socketChannel.register(selector, SelectionKey.OP_READ); } else if (key.isReadable()) { // 处理读取事件,从 SocketChannel 中读取数据 SocketChannel socketChannel = (SocketChannel) key.channel(); // 创建一个任务,交给线程池执行 executor.execute(new Task(socketChannel)); } iterator.remove(); } } ``` 4. 创建一个 `Task` 类,用于在线程池中执行具体的任务。 ```java class Task implements Runnable { private SocketChannel socketChannel; public Task(SocketChannel socketChannel) { this.socketChannel = socketChannel; } @Override public void run() { // 处理具体的事务逻辑,读取数据并做相应的处理 // ... } } ``` 在上述代码中,我们使用 `Selector` 监听 `OP_ACCEPT` 事件和 `OP_READ` 事件。当有新的连接请求时,将新的 `SocketChannel` 注册到 `Selector` 上以便监听读取事件。当有数据可读时,将读取任务交给线程池中的线程执行。 通过这种方式,你可以实现一个多线程处理事务的服务端。每个客户端连接都可以在独立的线程中处理,从而提高并发性能。 希望对你有帮助!如果还有其他问题,请随时提问。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值