重点:
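Selector 的 register() 和 select() 方法使用的是同一把锁：当一个线程阻塞在 select() 中时，另一个线程调用 register() 也会被阻塞，所以需要注意避免这种阻塞（例如只在执行 select() 的线程中注册通道，并在提交注册请求后调用 selector.wakeup()）。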
https://bugs.openjdk.java.net/browse/JDK-6446653
以下 DEMO 的基本思路参考了 Tomcat 的实现（Acceptor 线程负责接收连接，Poller 线程负责 select）。
备注：由于 Java NIO 的 Selector 采用水平触发（level-triggered），只要通道可写并且关注了 OP_WRITE，每次 select 都会返回 write 事件；为了避免不断产生 write 事件，需要引入 BlockingQueue 来控制写流程。
服务端
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * NIO server modelled on Tomcat's acceptor/poller split.
 *
 * <p>Threading model:
 * <ul>
 *   <li>{@link #doAccept()} runs on an {@link #acceptEs} thread and blocks in
 *       {@link ServerSocketChannel#accept()}. Each accepted channel is handed to the poller via
 *       {@link #queue}, then {@link Selector#wakeup()} is called so a poller blocked in
 *       {@code select()} picks it up immediately.</li>
 *   <li>{@link #doRun()} runs on a single {@link #pollEs} thread and is the only code that
 *       registers channels with, or selects on, {@link #selector}. Keeping register() and
 *       select() on the same thread avoids register() blocking behind a concurrent select()
 *       (JDK-6446653).</li>
 * </ul>
 *
 * <p>Protocol: each connection is read once, answered with {@code "server response<n>"}, and
 * then closed.
 */
public class NIOServer implements Runnable {
    private static final Integer size = 10;
    /** Number of ready keys processed so far; its value is appended to each response. */
    AtomicInteger count = new AtomicInteger();
    AtomicInteger pollCount = new AtomicInteger();
    AtomicInteger acceptCount = new AtomicInteger();
    /** Newly accepted channels waiting to be registered by the poller thread. */
    ArrayBlockingQueue<ChannelWrapper> queue = new ArrayBlockingQueue<>(1000);
    /** Multiplexer; registered with and selected on only from the poller thread. */
    private Selector selector;
    /** Only 2 bytes, so a request is typically read in several chunks; poller thread only. */
    private ByteBuffer readBuffer = ByteBuffer.allocate(2);
    private ExecutorService pollEs = Executors.newFixedThreadPool(size);
    private ExecutorService acceptEs = Executors.newFixedThreadPool(10);
    private ServerSocketChannel serverChannel;

    public NIOServer(int port) {
        init(port);
    }

    /**
     * Starts the server test on port 9999.
     */
    public static void main(String[] args) throws IOException {
        new Thread(new NIOServer(9999)).start();
    }

    /**
     * Opens the selector and binds a blocking server channel to {@code port}.
     *
     * @throws UncheckedIOException if the selector cannot be opened or the port cannot be bound;
     *     previously the error was swallowed, leaving null fields that made the accept loop spin
     *     on NullPointerException forever
     */
    private void init(int port) {
        try {
            System.out.println("Server staring at port: " + port);
            this.selector = Selector.open();
            serverChannel = ServerSocketChannel.open();
            // Blocking mode: the dedicated acceptor thread parks in accept() instead of sharing
            // the selector with the poller.
            serverChannel.configureBlocking(true);
            serverChannel.bind(new InetSocketAddress(port));
            System.out.println("Server started");
        } catch (IOException e) {
            closeQuietly(serverChannel);
            closeQuietly(selector);
            throw new UncheckedIOException("Failed to start server on port " + port, e);
        }
    }

    /** Starts one acceptor task and one poller task; returns immediately. */
    @Override
    public void run() {
        acceptEs.submit(this::doAccept);
        pollEs.submit(this::doRun);
    }

    /**
     * Acceptor loop: accepts connections, queues them for the poller and wakes the selector.
     * Exits when the server channel is closed or the thread is interrupted.
     */
    private void doAccept() {
        while (serverChannel.isOpen()) {
            System.out.println(Thread.currentThread() + "---------------------------------------accept---------------------------------------------------");
            SocketChannel channel = null;
            try {
                channel = serverChannel.accept();
                channel.configureBlocking(false);
                ChannelWrapper wrapper = new ChannelWrapper();
                wrapper.channel = channel;
                wrapper.first = true;
                queue.put(wrapper);
                // Without this the poller could stay blocked in select() (on keys that never
                // become ready) while this channel sits unregistered in the queue.
                selector.wakeup();
            } catch (InterruptedException e) {
                closeQuietly(channel);
                Thread.currentThread().interrupt();
                return;
            } catch (IOException e) {
                closeQuietly(channel);
                e.printStackTrace();
            }
        }
    }

    /**
     * Poller loop: registers queued channels, waits for ready keys and dispatches them.
     *
     * <p>The queue is drained with a non-blocking poll() rather than take(), so the loop never
     * parks on the queue while registered keys are ready; the acceptor's wakeup() ensures it
     * never parks in select() while accepted channels are still unregistered.
     */
    private void doRun() {
        while (this.selector.isOpen()) {
            System.out.println(Thread.currentThread() + "---------------------------------------poll---------------------------------------------------");
            try {
                registerPending();
                this.selector.select();
                Iterator<SelectionKey> keys = this.selector.selectedKeys().iterator();
                while (keys.hasNext()) {
                    SelectionKey key = keys.next();
                    keys.remove();
                    if (key.isValid()) {
                        dispatch(key);
                    }
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /** Registers every channel currently queued by the acceptor for OP_READ; never blocks. */
    private void registerPending() {
        ChannelWrapper wrapper;
        while ((wrapper = queue.poll()) != null) {
            if (!wrapper.first) {
                continue;
            }
            try {
                wrapper.channel.configureBlocking(false);
                wrapper.channel.register(this.selector, SelectionKey.OP_READ);
            } catch (IOException e) {
                closeQuietly(wrapper.channel);
                e.printStackTrace();
            }
        }
    }

    /** Handles one ready key; on any failure the key is cancelled and its channel closed. */
    private void dispatch(SelectionKey key) {
        System.out.println("-------------poll select-------------" + key.isAcceptable() + key.isReadable() + key.isWritable());
        count.incrementAndGet();
        try {
            if (key.isReadable()) {
                System.out.println(Thread.currentThread() + "- " + key + "read --- " + count.get());
                read(key);
            }
            // read() may have closed the channel; isWritable() would then throw
            // CancelledKeyException.
            if (key.isValid() && key.isWritable()) {
                System.out.println(Thread.currentThread() + "- " + key + "write --- " + count.get());
                write(key);
            }
        } catch (Exception e) {
            e.printStackTrace();
            key.cancel();
            closeQuietly(key.channel());
        }
    }

    /**
     * Accepts a connection from a selector-registered server channel and registers it for reads.
     *
     * <p>Not called in this class: here the server channel is blocking and served by
     * {@link #doAccept()}. It would apply if the server channel were non-blocking and registered
     * with OP_ACCEPT.
     */
    private void accept(SelectionKey key) {
        SocketChannel channel = null;
        try {
            ServerSocketChannel server = (ServerSocketChannel) key.channel();
            channel = server.accept();
            if (channel == null) {
                // A non-blocking accept() returns null when no connection is pending.
                return;
            }
            channel.configureBlocking(false);
            channel.register(this.selector, SelectionKey.OP_READ);
        } catch (IOException e) {
            e.printStackTrace();
            closeQuietly(channel);
        }
    }

    /**
     * Reads everything currently available on the key's channel, prints it, and switches the key
     * to OP_WRITE so the response is sent on a following select.
     * If the peer closed without sending anything, the channel is closed instead.
     */
    private void read(SelectionKey key) {
        SocketChannel channel = (SocketChannel) key.channel();
        try {
            // Collect raw bytes and decode once: decoding each 2-byte chunk separately would
            // garble multi-byte UTF-8 characters split across chunks.
            ByteArrayOutputStream received = new ByteArrayOutputStream();
            int readLength;
            this.readBuffer.clear();
            // 0 means "nothing more available right now"; -1 means the peer shut down output.
            while ((readLength = channel.read(this.readBuffer)) > 0) {
                this.readBuffer.flip();
                received.write(this.readBuffer.array(), this.readBuffer.position(), this.readBuffer.remaining());
                this.readBuffer.clear();
            }
            if (readLength < 0 && received.size() == 0) {
                // Peer closed without sending a request: nothing to answer.
                channel.close();
                return;
            }
            String input = new String(received.toByteArray(), StandardCharsets.UTF_8);
            System.out.println("From Clinet -------------> " + input);
            // Level-triggered selector: only show interest in OP_WRITE once there is something
            // to send, otherwise every select() would report the key as writable.
            key.interestOps(SelectionKey.OP_WRITE);
        } catch (IOException e) {
            e.printStackTrace();
            closeQuietly(channel);
        }
    }

    /**
     * Sends {@code "server response" + count} and closes the channel once it is fully written.
     * A partially written response is kept as the key attachment and resumed on the next
     * writable event instead of being dropped.
     */
    private void write(SelectionKey key) {
        SocketChannel channel = (SocketChannel) key.channel();
        try {
            ByteBuffer pending = (ByteBuffer) key.attachment();
            if (pending == null) {
                String line = "server response" + count.get();
                pending = ByteBuffer.wrap(line.getBytes(StandardCharsets.UTF_8));
                key.attach(pending);
                System.out.println("To Clinet write: " + line);
            }
            channel.write(pending);
            if (pending.hasRemaining()) {
                // Socket send buffer is full; keep OP_WRITE and continue on the next event.
                return;
            }
            channel.close();
        } catch (IOException e) {
            e.printStackTrace();
            closeQuietly(channel);
        }
    }

    /** Closes {@code resource} if non-null, ignoring close failures (best-effort cleanup). */
    private static void closeQuietly(Closeable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (IOException ignored) {
            // Nothing useful to do: the resource is being discarded anyway.
        }
    }

    /** Hand-off entry from the acceptor thread to the poller thread. */
    static class ChannelWrapper {
        SocketChannel channel;
        /** True for a freshly accepted channel that still has to be registered. */
        boolean first = false;
    }
}
客户端
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * NIO client used to exercise the server on {@code localhost:9999}.
 *
 * <p>Each request opens a new blocking {@link SocketChannel}, sends one short message and reads
 * until end-of-stream. The server answers once and then closes the connection, which is what
 * ends each read.
 */
public class NIOClient {
    /** Number of sequential requests issued by {@link #main}. */
    private static final int REQUEST_COUNT = 5000;
    /**
     * Numbers requests across calls. A counter local to {@link #dorequest()} restarts on every
     * call and labels every request "request 1".
     */
    private static final AtomicInteger REQUEST_SEQ = new AtomicInteger();

    /**
     * Starts the client test: sends {@link #REQUEST_COUNT} requests one after another and prints
     * the total elapsed time in milliseconds.
     */
    public static void main(String[] args) throws IOException {
        long start = System.currentTimeMillis();
        for (int i = 0; i < REQUEST_COUNT; i++) {
            dorequest();
        }
        long end = System.currentTimeMillis();
        System.out.println((end - start));
    }

    /**
     * Sends one request and prints the server's reply. I/O failures are printed and swallowed so
     * that one failed request does not abort the whole run.
     */
    private static void dorequest() throws IOException {
        InetSocketAddress remote = new InetSocketAddress("localhost", 9999);
        // try-with-resources closes the channel even when connect/write/read throws.
        try (SocketChannel channel = SocketChannel.open()) {
            channel.connect(remote);
            System.out.println("send to server>");
            String line = "request " + REQUEST_SEQ.incrementAndGet();
            // The channel is in blocking mode, so write() returns only after sending every byte.
            channel.write(ByteBuffer.wrap(line.getBytes("UTF-8")));
            System.out.println("response start ->");
            ByteBuffer readBuffer = ByteBuffer.allocate(1024);
            while (true) {
                // Grow before reading: a read into a full buffer would return 0 and end the loop
                // early.
                if (!readBuffer.hasRemaining()) {
                    readBuffer = grow(readBuffer);
                }
                int readLength = channel.read(readBuffer);
                System.out.println("read - " + readLength + " - limit - " + readBuffer.limit());
                if (readLength <= 0) {
                    break;
                }
            }
            // Decode once over all received bytes: decoding each chunk separately garbles any
            // multi-byte UTF-8 character split across a chunk boundary.
            readBuffer.flip();
            String response = new String(readBuffer.array(), 0, readBuffer.limit(), "UTF-8");
            System.out.println("From server -> " + response);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Returns a buffer of twice the capacity holding {@code full}'s bytes, ready for more data. */
    private static ByteBuffer grow(ByteBuffer full) {
        ByteBuffer bigger = ByteBuffer.allocate(full.capacity() * 2);
        full.flip();
        bigger.put(full);
        return bigger;
    }
}