花了一个星期的时间,通过在网上查阅的资料,了解了Java NIO的一些基本原理,相关API的基本使用。虽说这东西并不新鲜,但毕竟是第一次接触。
NIO非阻塞模型的优势在于:只需少量线程加上一个大小固定的线程池,便可以支撑相当数量的客户端并发访问。
在了解其基本原理和相关API的使用之后,动手写了一个demo,主要的类有:
NIOServer:NIO服务器端,用于接收多个客户端请求,并且对多个客户端请求进行响应
NIOClient:客户端,访问服务端并且发送请求,接收服务端响应
ReadHandler:ServerReadHandler和ClientReadHandler的抽象父类,主要负责处理OP_READ(读)事件读取到的数据
ServerReadHandler:服务端用于处理OP_READ(读)事件的处理类
ClientReadHandler:客户端用于处理OP_READ(读)事件的处理类
MsgPacket:用于封装请求数据和响应数据的类
下面是具体的代码:
NIOServer
package com.study.nio;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
/**
 * Demo NIO server.
 *
 * <p>A single selector thread accepts clients and reads their bytes. Decoding and replying
 * happen on a worker pool via {@link ServerReadHandler}. Replies are queued per client and
 * drained when the channel becomes writable (OP_WRITE).
 */
public class NIOServer {
    // Read buffer size in bytes: larger is better, but don't overdo it (must be at least 2).
    private static final int BUFFER_SIZE = 1024 * 8;
    // Selector multiplexing the listening channel and every client channel.
    private Selector selector;
    // Worker pool that runs ReadHandler.handle() off the selector thread.
    private ThreadPoolExecutor readPoolExecutor;
    // Shared read buffer; only used by the selector thread.
    private ByteBuffer byteBuffer;
    // Pending outbound buffers per client channel (queue contents guarded by this).
    private Map<SocketChannel, LinkedList<ByteBuffer>> writeBufferMap;
    // ReadHandler per client channel (one handler per client).
    private Map<SocketChannel, ReadHandler> SRMap;
    // Number of currently connected clients (updated on the selector thread).
    private int connnectCount;
    private String ip;
    private int port;

    public NIOServer() {}

    public NIOServer(String ip, int port) {
        this.ip = ip;
        this.port = port;
    }

    /**
     * Allocates the worker pool, buffers and bookkeeping maps, then binds the listening channel.
     *
     * @throws IllegalStateException if the selector cannot be opened or the address cannot be bound
     */
    private void init() {
        final int processors = Runtime.getRuntime().availableProcessors();
        readPoolExecutor = (ThreadPoolExecutor) Executors.newFixedThreadPool(processors * 2);
        byteBuffer = ByteBuffer.allocate(BUFFER_SIZE);
        writeBufferMap = new ConcurrentHashMap<SocketChannel, LinkedList<ByteBuffer>>();
        SRMap = new ConcurrentHashMap<SocketChannel, ReadHandler>();
        connnectCount = 0;
        ServerSocketChannel serverSocketChannel = null;
        try {
            serverSocketChannel = ServerSocketChannel.open();
            serverSocketChannel.socket().bind(new InetSocketAddress(ip, port));
            serverSocketChannel.configureBlocking(false);
            selector = Selector.open();
            serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);
        } catch (IOException e) {
            // Without a bound channel and a selector, the event loop would only die with an NPE.
            // Release whatever was opened and fail fast instead.
            if (serverSocketChannel != null) {
                try {
                    serverSocketChannel.close();
                } catch (IOException ignored) {
                    // best effort during a failed startup
                }
            }
            if (selector != null) {
                try {
                    selector.close();
                } catch (IOException ignored) {
                    // best effort during a failed startup
                }
                selector = null;
            }
            readPoolExecutor.shutdown();
            throw new IllegalStateException("server failed to listen on " + ip + ":" + port, e);
        }
    }

    /**
     * Initializes the server and starts the selector thread.
     *
     * @throws IllegalStateException if the listening socket cannot be set up
     */
    public void startup() {
        init();
        System.out.println(Report.reportCurrentTime() + "server startup...");
        System.out.println(Report.reportCurrentTime() + "server listen on " + ip + " port " + port);
        new Thread(new Runnable() {
            @Override
            public void run() {
                while (!Thread.interrupted()) {
                    try {
                        // Blocks until at least one channel is ready (or wakeup() is called).
                        int nKey = selector.select();
                        if (nKey > 0) {
                            Set<SelectionKey> keySet = selector.selectedKeys();
                            Iterator<SelectionKey> iterator = keySet.iterator();
                            while (iterator.hasNext()) {
                                SelectionKey key = iterator.next();
                                // Remove so the key is not processed twice.
                                iterator.remove();
                                dispatch(key);
                            }
                        }
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
            }
        }).start();
    }

    // Routes one selected key to the accept, read and/or write handler.
    private void dispatch(SelectionKey key) {
        try {
            if (key.isValid() && key.isAcceptable()) {
                acceptConnection(key);
                return;
            }
            // A client key may be readable and writable at once (interest is READ|WRITE while
            // replies are pending), so both are serviced in the same round.
            if (key.isValid() && key.isReadable()) {
                readFromChannel(key);
            }
            if (key.isValid() && key.isWritable()) {
                writeToChannel(key);
            }
        } catch (CancelledKeyException e) {
            // The key was cancelled between the validity check and its use; the channel is gone.
        }
    }

    // Accepts one pending client connection and sets up its handler and write queue.
    private void acceptConnection(SelectionKey key) {
        ServerSocketChannel serverSocketChannel = (ServerSocketChannel) key.channel();
        SocketChannel socketChannel = null;
        try {
            socketChannel = serverSocketChannel.accept();
            if (socketChannel == null) {
                // Non-blocking accept() returns null when no connection is actually pending.
                return;
            }
            socketChannel.configureBlocking(false);
            SelectionKey clientKey = socketChannel.register(selector, SelectionKey.OP_READ);
            // Save the channel's ReadHandler and its pending-write queue.
            LinkedList<ByteBuffer> bufferQueue = new LinkedList<ByteBuffer>();
            writeBufferMap.put(socketChannel, bufferQueue);
            SRMap.put(socketChannel, new ServerReadHandler(socketChannel, this));
            connnectCount++;
            // Greeting (test write). A partial write must not be dropped, or the client's
            // framing breaks; queue whatever remains for OP_WRITE.
            ByteBuffer welcome = ByteBuffer.wrap(new MsgPacket(Report.reportCurrentTime() + "欢迎来到本地服务器").getBytes());
            socketChannel.write(welcome);
            if (welcome.hasRemaining()) {
                synchronized (this) {
                    bufferQueue.add(welcome);
                }
                clientKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
            }
            System.out.println(Report.reportCurrentTime() + "accept one Client");
        } catch (IOException e) {
            e.printStackTrace();
            if (socketChannel != null) {
                // Don't leak the half-initialized connection.
                try {
                    socketChannel.close();
                } catch (IOException e1) {
                    e1.printStackTrace();
                }
                releaseResource(socketChannel);
            }
        }
    }

    // Reads available bytes from a client and schedules decoding on the worker pool.
    private void readFromChannel(SelectionKey key) {
        final SocketChannel socketChannel = (SocketChannel) key.channel();
        synchronized (byteBuffer) {
            byteBuffer.clear();
            final int count;
            try {
                count = socketChannel.read(byteBuffer);
            } catch (IOException e) {
                // Abnormal client disconnect (e.g. connection reset).
                e.printStackTrace();
                closeChannel(key, socketChannel);
                return;
            }
            if (count > 0) {
                final byte[] data = new byte[count];
                System.arraycopy(byteBuffer.array(), 0, data, 0, count);
                final ReadHandler readHandler = SRMap.get(socketChannel);
                if (readHandler == null) {
                    // Resources were already released for this channel.
                    return;
                }
                readHandler.read(data);
                readPoolExecutor.execute(new Runnable() {
                    @Override
                    public void run() {
                        readHandler.handle();
                    }
                });
            } else if (count < 0) {
                // The client closed the connection.
                closeChannel(key, socketChannel);
                System.out.println("客户端主动断开连接," + " 剩余连接数: " + connnectCount);
            }
        }
    }

    // Cancels the key, closes the channel and releases its per-client state.
    private void closeChannel(SelectionKey key, SocketChannel socketChannel) {
        key.cancel();
        try {
            if (socketChannel.isOpen()) {
                socketChannel.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        releaseResource(socketChannel);
    }

    /**
     * Queues a reply to one client (simulated response) and asks the selector thread to flush it.
     *
     * <p>Called from worker threads. The reply is silently dropped if the client has already
     * disconnected.
     *
     * @param socketChannel the client channel to reply to
     * @param msg the reply text, framed with {@link MsgPacket}
     */
    public synchronized void respone(SocketChannel socketChannel, String msg) {
        LinkedList<ByteBuffer> bufferQueue = writeBufferMap.get(socketChannel);
        SelectionKey key = socketChannel.keyFor(this.selector);
        if (bufferQueue == null || key == null || !key.isValid()) {
            return;
        }
        MsgPacket msgPacket = new MsgPacket(msg);
        bufferQueue.add(ByteBuffer.wrap(msgPacket.getBytes()));
        try {
            // Keep OP_READ so the client can still be read while the reply waits to be sent.
            key.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
        } catch (CancelledKeyException e) {
            // Client disconnected concurrently; the queue is released with the channel.
            return;
        }
        selector.wakeup();
    }

    // Drains as much of the client's write queue as the socket accepts.
    private synchronized void writeToChannel(SelectionKey key) {
        SocketChannel socketChannel = (SocketChannel) key.channel();
        LinkedList<ByteBuffer> bufferQueue = writeBufferMap.get(socketChannel);
        if (bufferQueue == null) {
            key.interestOps(SelectionKey.OP_READ);
            return;
        }
        try {
            while (!bufferQueue.isEmpty()) {
                ByteBuffer buffer = bufferQueue.getFirst();
                socketChannel.write(buffer);
                if (buffer.hasRemaining()) {
                    // The socket send buffer is full; continue on the next OP_WRITE.
                    return;
                }
                bufferQueue.removeFirst();
            }
        } catch (IOException e) {
            // Write failed (client gone). Stop here: retrying the same buffer would loop forever.
            e.printStackTrace();
            closeChannel(key, socketChannel);
            return;
        }
        // Everything was written. Drop OP_WRITE, otherwise select() returns immediately on every
        // iteration and the loop spins at 100% CPU.
        key.interestOps(SelectionKey.OP_READ);
    }

    // Removes a channel's handler and write queue. Decrements the connection count only on the
    // first release, so repeated releases for the same channel are harmless.
    private synchronized void releaseResource(SocketChannel socketChannel) {
        if (SRMap.remove(socketChannel) != null) {
            connnectCount--;
        }
        writeBufferMap.remove(socketChannel);
    }

    public static void main(String[] args) {
        NIOServer server = new NIOServer("127.0.0.1", 9000);
        server.startup();
    }
}
NIOClient
package com.study.nio;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Set;
/**
 * Demo NIO client.
 *
 * <p>The client connects without blocking and sends two greeting packets once connected. It
 * then prints every packet the server sends back. All socket I/O runs on one selector thread,
 * which closes the client after MAX_IDLE_COUNT consecutive one-second selects with no ready
 * keys.
 */
public class NIOClient {
    // Number of consecutive empty 1-second selects before the client closes itself.
    private static final int MAX_IDLE_COUNT = 60;
    // Read buffer size in bytes.
    private static final int BUFFER_SIZE = 1024 * 8;
    // Stop flag for the selector loop. volatile because close() is public and may run on
    // another thread.
    private volatile boolean isClosed = false;
    // Selector for this client's single channel.
    private Selector selector;
    // Channel to the server; assigned as soon as it is opened so close() can always release it.
    private SocketChannel socketChannel;
    // Read buffer (selector thread only).
    private ByteBuffer readBuffer;
    // Decodes packets received from the server.
    private ReadHandler readHandler;
    // Pending outbound buffers (guarded by this).
    private LinkedList<ByteBuffer> bufferQueue;
    // Consecutive idle select rounds.
    private int idleCount;
    private String serverIP;
    private int port;
    private String name;

    public String getName() {
        return this.name;
    }

    public NIOClient() {}

    public NIOClient(String serverIP, int port) {
        this.serverIP = serverIP;
        this.port = port;
    }

    public NIOClient(String name, String serverIP, int port) {
        this.name = name;
        this.serverIP = serverIP;
        this.port = port;
    }

    // Allocates buffers, opens the selector and starts a non-blocking connect.
    // Returns false (with all resources released) if setup fails.
    private boolean init() {
        idleCount = 0;
        readBuffer = ByteBuffer.allocate(BUFFER_SIZE);
        bufferQueue = new LinkedList<ByteBuffer>();
        try {
            selector = Selector.open();
            socketChannel = SocketChannel.open();
            // Must be non-blocking to be used with a selector.
            socketChannel.configureBlocking(false);
            SelectionKey key = socketChannel.register(selector, SelectionKey.OP_CONNECT);
            if (socketChannel.connect(new InetSocketAddress(serverIP, port))) {
                // Connected immediately (possible on loopback). OP_CONNECT will never fire for an
                // already-connected channel, so finish the handshake here.
                onConnected(key);
            }
            return true;
        } catch (IOException e) {
            e.printStackTrace();
            close();
            closeSelector();
            return false;
        }
    }

    /** Initializes the client and starts its selector thread. Does nothing if setup fails. */
    public void startup() {
        if (!init()) {
            return;
        }
        new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    while (!isClosed) {
                        // One-second timeout; combined with MAX_IDLE_COUNT this gives an idle
                        // timeout of MAX_IDLE_COUNT seconds.
                        int nKey = selector.select(1000);
                        if (nKey > 0) {
                            idleCount = 0;
                            Set<SelectionKey> keySet = selector.selectedKeys();
                            Iterator<SelectionKey> iterator = keySet.iterator();
                            while (iterator.hasNext()) {
                                SelectionKey key = iterator.next();
                                iterator.remove();
                                if (!key.isValid()) {
                                    continue;
                                }
                                if (key.isConnectable()) {
                                    finishedConnection(key);
                                    continue;
                                }
                                if (key.isReadable()) {
                                    readFromChanel(key);
                                }
                                if (key.isValid() && key.isWritable()) {
                                    writeToChannel(key);
                                }
                            }
                        } else {
                            idleCount++;
                            if (idleCount >= MAX_IDLE_COUNT) {
                                // Idle timeout: disconnect from the server.
                                close();
                            }
                        }
                    }
                } catch (IOException e) {
                    e.printStackTrace();
                } finally {
                    // The loop thread owns the selector and closes it on the way out.
                    close();
                    closeSelector();
                }
            }
        }).start();
    }

    // Completes a pending non-blocking connect. If the connect failed, closes the client.
    private void finishedConnection(SelectionKey key) {
        SocketChannel channel = (SocketChannel) key.channel();
        try {
            if (channel.isConnectionPending()) {
                channel.finishConnect();
            }
            if (channel.isConnected()) {
                onConnected(key);
            }
            // Otherwise the connect is still in progress; wait for the next OP_CONNECT.
        } catch (IOException e) {
            // Connection refused or unreachable.
            e.printStackTrace();
            close();
        }
    }

    // Switches the key to reading, creates the packet handler and sends the greetings.
    private void onConnected(SelectionKey key) {
        key.interestOps(SelectionKey.OP_READ);
        readHandler = new ClientReadHandler(this, this.socketChannel);
        System.out.println(Report.reportCurrentTime() + this.name + " Connect to Server");
        final String msg = "本地服务器你好!" + "我是" + this.getName();
        final String msg1 = "我是" + this.getName() + " 本地服务器你好!";
        send(msg);
        send(msg1);
    }

    // Reads available bytes and decodes them synchronously on the selector thread.
    private void readFromChanel(SelectionKey key) {
        SocketChannel channel = (SocketChannel) key.channel();
        readBuffer.clear();
        int count;
        try {
            count = channel.read(readBuffer);
        } catch (IOException e) {
            e.printStackTrace();
            close();
            return;
        }
        if (count > 0) {
            byte[] data = new byte[count];
            System.arraycopy(readBuffer.array(), 0, data, 0, count);
            readHandler.read(data);
            readHandler.handle();
        } else if (count < 0) {
            // The server closed the connection. The key would stay readable forever and
            // select() would spin, so shut the client down.
            close();
        }
    }

    public boolean isClosed() {
        return isClosed;
    }

    /**
     * Stops the client: sets the stop flag, closes the channel and wakes the selector thread.
     *
     * <p>The selector itself is closed by the selector thread when it exits, or by init() if the
     * thread never started. Safe to call more than once.
     */
    public void close() {
        isClosed = true;
        if (socketChannel != null) {
            try {
                socketChannel.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        Selector sel = selector;
        if (sel != null && sel.isOpen()) {
            sel.wakeup();
        }
    }

    // Closes the selector, reporting (but otherwise ignoring) failures.
    private void closeSelector() {
        if (selector == null) {
            return;
        }
        try {
            selector.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Frames msg as a {@link MsgPacket} and queues it for sending.
     *
     * <p>Messages queued before the connection completes are flushed once it does.
     *
     * @param msg the text to send
     * @throws IllegalStateException if the client has not been started
     */
    public synchronized void send(String msg) {
        if (bufferQueue == null) {
            throw new IllegalStateException("client not started");
        }
        MsgPacket msgPacket = new MsgPacket(msg);
        bufferQueue.add(ByteBuffer.wrap(msgPacket.getBytes()));
        if (socketChannel == null || !socketChannel.isConnected()) {
            // Still connecting: replacing OP_CONNECT here would stall the handshake.
            return;
        }
        SelectionKey key = socketChannel.keyFor(this.selector);
        if (key == null || !key.isValid()) {
            return;
        }
        // Keep OP_READ so replies are still received while writes are pending.
        key.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
        selector.wakeup();
    }

    // Drains as much of the write queue as the socket accepts.
    private synchronized void writeToChannel(SelectionKey key) {
        SocketChannel channel = (SocketChannel) key.channel();
        try {
            while (!bufferQueue.isEmpty()) {
                ByteBuffer buffer = bufferQueue.getFirst();
                channel.write(buffer);
                if (buffer.hasRemaining()) {
                    // The socket send buffer is full; continue on the next OP_WRITE.
                    return;
                }
                bufferQueue.removeFirst();
            }
        } catch (IOException e) {
            // Write failed. Stop here: retrying the same buffer would loop forever.
            e.printStackTrace();
            close();
            return;
        }
        // All data written: stop waiting for OP_WRITE.
        key.interestOps(SelectionKey.OP_READ);
    }

    public static void main(String[] args) {
        NIOClient client = new NIOClient("127.0.0.1", 9000);
        client.startup();
        System.out.println(Report.reportCurrentTime() + " client startup");
    }
}
ReadHandler
package com.study.nio;
import java.nio.channels.SocketChannel;
import java.util.LinkedList;
/**
 * Reassembles {@link MsgPacket}s from a raw byte stream.
 *
 * <p>The I/O side calls {@link #read(byte[])} with every chunk received from the channel.
 * {@link #handle()} then decodes as many complete packets as the buffered bytes allow and
 * passes each one to {@link #response(MsgPacket)}. A packet whose header or body is split
 * across chunks stays at the head of {@link #packetQueue} until the rest arrives. Both entry
 * points are synchronized on this handler.
 */
public abstract class ReadHandler {
    // Received chunks not yet decoded.
    protected LinkedList<byte[]> dataList = new LinkedList<byte[]>();
    // Total number of bytes currently held in dataList.
    protected int readCount = 0;
    // Partially received packet waiting for more bytes (at most one at a time).
    protected LinkedList<MsgPacket> packetQueue = new LinkedList<MsgPacket>();
    // Channel this handler belongs to.
    protected SocketChannel socketChannel;

    public ReadHandler() {}

    public ReadHandler(SocketChannel socketChannel) {
        this.socketChannel = socketChannel;
    }

    /** Buffers one received chunk for decoding by {@link #handle()}. */
    public synchronized void read(byte[] data) {
        dataList.add(data);
        readCount = readCount + data.length;
    }

    /**
     * Decodes all buffered bytes and calls {@link #response(MsgPacket)} for every packet that
     * becomes complete.
     *
     * <p>Bytes are consumed strictly in stream order. If a packet is pending, incoming bytes
     * continue it, whatever their value. Otherwise the next byte must be
     * {@link MsgPacket#MSG_FLAG}. If it is not, framing is lost and the rest of the buffered
     * bytes are discarded.
     */
    public synchronized void handle() {
        if (readCount == 0) {
            return;
        }
        byte[] bytes = toBytes();
        int pos = 0;
        while (pos < bytes.length) {
            MsgPacket packet = packetQueue.isEmpty() ? null : packetQueue.getFirst();
            if (packet == null) {
                if (bytes[pos] != MsgPacket.MSG_FLAG) {
                    // Not at a packet boundary and nothing is pending, so the remaining bytes
                    // cannot be interpreted. Drop them rather than looping on them.
                    System.err.println("ReadHandler: unexpected byte " + bytes[pos]
                            + ", discarding " + (bytes.length - pos) + " bytes");
                    return;
                }
                // Start a new packet with as much of its header as is available.
                int headerBytes = Math.min(MsgPacket.HEADER_SIZE, bytes.length - pos);
                packet = new MsgPacket(copyRange(bytes, pos, headerBytes));
                pos += headerBytes;
                packetQueue.add(packet);
            } else if (!packet.isHeaderCompleted()) {
                // Continue a header that was split across reads.
                pos += completeHeader(packet, bytes, pos);
            }
            if (!packet.isHeaderCompleted()) {
                // Still not enough header bytes; wait for the next read.
                return;
            }
            // Copy as much of the body as this chunk holds (possibly zero bytes).
            int take = Math.min(packet.getNeedDataLength(), bytes.length - pos);
            packet.read(copyRange(bytes, pos, take));
            pos += take;
            if (!packet.isCompleted()) {
                // Body continues in a later read; the packet stays queued.
                return;
            }
            // Dequeue before responding so a failing response cannot leave a stale packet.
            packetQueue.removeFirst();
            response(packet);
        }
    }

    // Appends up to the missing header bytes from bytes[pos..] to the packet's partial header.
    // Computes the body length once the header is complete. Returns the number of bytes used.
    private int completeHeader(MsgPacket packet, byte[] bytes, int pos) {
        byte[] partial = packet.getHeader();
        int take = Math.min(MsgPacket.HEADER_SIZE - partial.length, bytes.length - pos);
        byte[] header = new byte[partial.length + take];
        System.arraycopy(partial, 0, header, 0, partial.length);
        System.arraycopy(bytes, pos, header, partial.length, take);
        packet.resetHeader(header);
        if (packet.isHeaderCompleted()) {
            packet.calLength();
        }
        return take;
    }

    // Returns a new array holding src[from, from + length).
    private static byte[] copyRange(byte[] src, int from, int length) {
        byte[] out = new byte[length];
        System.arraycopy(src, from, out, 0, length);
        return out;
    }

    // Concatenates all buffered chunks into one array and resets the buffer.
    private byte[] toBytes() {
        byte[] bytes = new byte[readCount];
        int destPos = 0;
        while (!dataList.isEmpty()) {
            byte[] bytes0 = dataList.remove(0);
            System.arraycopy(bytes0, 0, bytes, destPos, bytes0.length);
            destPos = destPos + bytes0.length;
        }
        readCount = 0;
        return bytes;
    }

    // Returns a new array holding the last `more` bytes of `bytes`. No longer used by handle();
    // kept for subclasses.
    protected byte[] adjust(byte[] bytes, int more) {
        byte[] tmp = new byte[more];
        System.arraycopy(bytes, bytes.length - more, tmp, 0, more);
        return tmp;
    }

    /** Called once for every fully reassembled packet, while this handler's lock is held. */
    protected abstract void response(MsgPacket msgPacket);
}
ServerReadHandler
package com.study.nio;
import java.nio.channels.SocketChannel;
public class ServerReadHandler extends ReadHandler {
private NIOServer nioServer;
public ServerReadHandler(SocketChannel socketChannel, NIOServer nioServer) {
super(socketChannel);
this.nioServer = nioServer;
}
protected synchronized void response(MsgPacket msgPacket) {
final String content = new String(msgPacket.getData());
System.out.println(Report.reportCurrentTime() + "receive content -> "
+ content);
// 模拟耗时操作
try {
Thread.sleep(100);
// 模拟一下回复客户端
nioServer.respone(socketChannel, Report.reportCurrentTime() + "Server reply -> " + content);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
ClientReadHandler
package com.study.nio;
import java.nio.channels.SocketChannel;
public class ClientReadHandler extends ReadHandler {
private NIOClient nioClient;
public ClientReadHandler(NIOClient nioClient, SocketChannel socketChannel) {
super(socketChannel);
this.nioClient = nioClient;
}
protected void response(MsgPacket msgPacket) {
String content = new String(msgPacket.getData());
System.out.println(Report.reportCurrentTime() + nioClient.getName() + " receive content -> "
+ content);
}
}
package com.study.nio;
/**
 * Wire packet used to frame outgoing messages and reassemble incoming ones.
 *
 * <p>Layout: a 3-byte header followed by the payload. The header holds a type flag
 * ({@link #MSG_FLAG}) and then the payload length as an unsigned big-endian 16-bit value.
 *
 * @author CrazyPig
 */
public class MsgPacket {
    public static final byte MSG_FLAG = 0x01;
    public static final int HEADER_SIZE = 3;
    /** Largest payload length the 2-byte length field can encode. */
    public static final int MAX_DATA_LENGTH = 0xFFFF;
    // header[0] = packet type; header[1..2] = payload length (high byte first).
    private byte[] header = new byte[HEADER_SIZE];
    // Payload bytes.
    private byte[] data;
    // Total packet length (header + payload).
    private int length;
    // Number of payload bytes received so far.
    private int curDataLength = 0;

    public byte[] getData() {
        return this.data;
    }

    public byte[] getHeader() {
        return this.header;
    }

    public int getPacketLength() {
        return this.length;
    }

    public int getDataLength() {
        return this.length - header.length;
    }

    public int getCurDataLength() {
        return this.curDataLength;
    }

    /** Returns how many payload bytes are still missing. */
    public int getNeedDataLength() {
        return this.getDataLength() - this.getCurDataLength();
    }

    /**
     * Builds an outgoing packet for msg. The payload uses the platform default charset, which
     * matches the {@code new String(getData())} decoding used by the handlers.
     *
     * @throws IllegalArgumentException if the encoded message exceeds {@link #MAX_DATA_LENGTH}
     */
    public MsgPacket(String msg) {
        this.data = msg.getBytes();
        genHeader(data.length);
        this.length = data.length + header.length;
    }

    /**
     * Builds an incoming packet from a received header. The header may be partial (fewer than
     * HEADER_SIZE bytes). In that case the length is computed later, via resetHeader() and
     * calLength().
     */
    public MsgPacket(byte[] header) {
        this.header = header;
        if (header.length < HEADER_SIZE) {
            return;
        }
        calLength();
    }

    /** Decodes the payload length from a complete header and allocates the payload array. */
    public void calLength() {
        // Mask both bytes to treat them as unsigned before combining. Casting the shifted high
        // byte back to byte would zero it, and a sign-extended low byte would turn lengths of
        // 128..255 (mod 256) into a negative array size.
        int dataLength = ((header[1] & 0xff) << 8) | (header[2] & 0xff);
        this.data = new byte[dataLength];
        this.length = dataLength + header.length;
    }

    /** Appends newdata to the payload received so far. */
    public void read(byte[] newdata) {
        System.arraycopy(newdata, 0, this.data, this.curDataLength, newdata.length);
        this.curDataLength += newdata.length;
    }

    /** Returns true once the whole payload has been received. */
    public boolean isCompleted() {
        int curLength = this.curDataLength + this.header.length;
        return curLength == this.length;
    }

    /** Returns true once all HEADER_SIZE header bytes are present. */
    public boolean isHeaderCompleted() {
        return header.length == HEADER_SIZE;
    }

    /**
     * Writes the flag and the big-endian payload length into the header.
     *
     * @throws IllegalArgumentException if dataLength is outside 0..MAX_DATA_LENGTH, which would
     *     otherwise be silently truncated and corrupt the framing
     */
    public void genHeader(int dataLength) {
        if (dataLength < 0 || dataLength > MAX_DATA_LENGTH) {
            throw new IllegalArgumentException("data length " + dataLength
                    + " does not fit in the packet header (max " + MAX_DATA_LENGTH + ")");
        }
        header[0] = MSG_FLAG;
        header[1] = (byte) ((dataLength >> 8) & 0xff);
        header[2] = (byte) (dataLength & 0xff);
    }

    /** Returns the full wire representation (header followed by payload). */
    public byte[] getBytes() {
        byte[] allBytes = new byte[this.length];
        System.arraycopy(header, 0, allBytes, 0, this.header.length);
        System.arraycopy(data, 0, allBytes, this.header.length, this.data.length);
        return allBytes;
    }

    /** Replaces the header, e.g. after the missing bytes of a partial header have arrived. */
    public void resetHeader(byte[] header) {
        this.header = header;
    }
}
package com.study.nio;
/**
 * Simple load test. Launches CLIENT_SIZE clients against 127.0.0.1:9000; each one runs on its
 * own selector thread.
 */
public class NIOPowerTest {
    // Number of clients to launch.
    public static final int CLIENT_SIZE = 300;

    public static void main(String[] args) {
        NIOClient[] clients = new NIOClient[CLIENT_SIZE];
        int index = 0;
        while (index < clients.length) {
            NIOClient current = new NIOClient("Client" + index, "127.0.0.1", 9000);
            clients[index] = current;
            current.startup();
            System.out.println(Report.reportCurrentTime() + " CLIENT" + index + " startup");
            index++;
        }
    }
}