TTransport => TIOStreamTransport => TSocket（继承链 / inheritance chain）
重要参数设置（key socket settings）:
// Disable SO_LINGER (the default): close() returns immediately rather than blocking on unsent data.
socket_.setSoLinger(false, 0);
// Disable Nagle's algorithm so small writes are sent without coalescing delay.
socket_.setTcpNoDelay(true);
// Client read timeout in milliseconds: a blocking read fails once this elapses.
socket_.setSoTimeout(timeout_);
// Client connect timeout in milliseconds (the same timeout_ value is reused).
socket_.connect(new InetSocketAddress(host_, port_), timeout_);
/**
 * Base class for every Thrift transport: an ordered byte stream that can be
 * opened, closed, read from and written to.
 *
 * <p>The buffer-access hooks at the end of the class ({@link #getBuffer()},
 * {@link #getBufferPosition()}, {@link #getBytesRemainingInBuffer()} and
 * {@link #consumeBuffer(int)}) default to "no buffer exposed". Transports that
 * keep an in-memory read buffer (such as a framed transport) override them.
 */
public abstract class TTransport {

    /**
     * Reports whether the transport is currently open.
     * A socket-backed implementation answers this from {@code socket_.isConnected()}.
     */
    public abstract boolean isOpen();

    /**
     * Reports whether data may be available to read. This base version only
     * answers {@link #isOpen()}; it does not inspect any buffered bytes.
     */
    public boolean peek() {
        return this.isOpen();
    }

    /**
     * Opens the transport. For a socket this means connecting to the remote
     * host (e.g. {@code socket_.connect(new InetSocketAddress(host_, port_), timeout_)}).
     *
     * @throws TTransportException if the transport cannot be opened
     */
    public abstract void open() throws TTransportException;

    /** Closes the transport (for a socket: {@code socket_.close()}). */
    public abstract void close();

    /**
     * Reads up to {@code len} bytes into {@code buf}, starting at index {@code off}.
     *
     * @return the number of bytes actually read; {@link #readAll} treats any
     *         value {@code <= 0} as the remote side having closed
     * @throws TTransportException on a transport failure
     */
    public abstract int read(byte[] buf, int off, int len) throws TTransportException;

    /**
     * Reads exactly {@code len} bytes into {@code buf} starting at {@code off},
     * calling {@link #read} as many times as needed.
     *
     * @return {@code len} (the number of bytes read)
     * @throws TTransportException if a single {@code read} returns {@code <= 0}
     *         before {@code len} bytes have arrived
     */
    public int readAll(byte[] buf, int off, int len) throws TTransportException {
        int total = 0;
        while (total < len) {
            final int chunk = read(buf, off + total, len - total);
            if (chunk > 0) {
                total += chunk;
                continue;
            }
            throw new TTransportException(
                "Cannot read. Remote side has closed. Tried to read "
                    + len
                    + " bytes, but only got "
                    + total
                    + " bytes. (This is often indicative of an internal error on the server side. Please check your server logs.)");
        }
        return total;
    }

    /** Writes the whole of {@code buf}; shorthand for {@code write(buf, 0, buf.length)}. */
    public void write(byte[] buf) throws TTransportException {
        this.write(buf, 0, buf.length);
    }

    /**
     * Writes {@code len} bytes from {@code buf}, starting at index {@code off}.
     *
     * @throws TTransportException on a transport failure
     */
    public abstract void write(byte[] buf, int off, int len) throws TTransportException;

    /** Flushes pending writes. The base implementation does nothing. */
    public void flush() throws TTransportException {
        // nothing buffered at this level
    }

    /** Returns the internal read buffer, or {@code null} when none is exposed (the default). */
    public byte[] getBuffer() {
        return null;
    }

    /** Returns the current read position inside {@link #getBuffer()}; 0 by default. */
    public int getBufferPosition() {
        return 0;
    }

    /**
     * Returns how many unread bytes remain in {@link #getBuffer()}; the base
     * implementation returns -1 because it exposes no buffer.
     */
    public int getBytesRemainingInBuffer() {
        return -1;
    }

    /** Marks {@code len} buffered bytes as consumed. No-op by default. */
    public void consumeBuffer(int len) {
        // no buffer to advance
    }
}
TTransport => TNonblockingTransport => TNonblockingSocket（继承链 / inheritance chain）
重要参数（key setup steps）:
// Create the selector that will multiplex readiness events.
Selector selector = SelectorProvider.provider().openSelector();
// Create the SocketChannel; like a Socket, it wraps an underlying java.net.Socket.
SocketChannel socketChannel = SocketChannel.open();
// Switch the channel to non-blocking mode (required before registering with a selector).
socketChannel.configureBlocking(false);
// Fetch the wrapped Socket to tune its low-level options.
Socket socket = socketChannel.socket();
socket.setSoLinger(false, 0);
socket.setTcpNoDelay(true);
// NOTE(review): presumably applies the timeout via socket.setSoTimeout — confirm in TNonblockingSocket.
setTimeout(timeout);
// Register interest in the connect-completion event (uses the channel declared above).
socketChannel.register(selector, SelectionKey.OP_CONNECT);
/**
 * A {@link TTransport} that is driven by a java.nio {@link Selector} instead of
 * blocking a thread on every connect, read and write.
 */
public abstract class TNonblockingTransport extends TTransport {

    /**
     * Starts a non-blocking connect. The socket implementation delegates to
     * {@code socketChannel_.connect(socketAddress_)}.
     *
     * @return presumably {@code true} when the connection completed immediately,
     *         mirroring {@code SocketChannel.connect} — TODO confirm
     */
    public abstract boolean startConnect() throws IOException;

    /**
     * Completes a pending connect. The socket implementation delegates to
     * {@code socketChannel_.finishConnect()}. Once it succeeds, read/write
     * interest can be registered.
     */
    public abstract boolean finishConnect() throws IOException;

    /** Registers this transport's channel with {@code selector} for the given interest ops. */
    public abstract SelectionKey registerSelector(Selector selector, int interests) throws IOException;

    /** Reads available bytes into {@code buffer}; returns the count read. */
    public abstract int read(ByteBuffer buffer) throws IOException;

    /** Writes bytes from {@code buffer}; returns the count written. */
    public abstract int write(ByteBuffer buffer) throws IOException;
}
NIO 客户端 Demo（NIO client demo）
package sampleNio;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
/**
 * Minimal single-connection NIO client demo.
 *
 * <p>Connects to {@code hostAddress:port}, sends {@code "nice to meet you"},
 * prints the first chunk of the reply and closes the connection. The selector
 * loop in {@link #run()} exits once the connection is closed and then closes
 * the selector.
 *
 * @author jason
 */
public class NioClient implements Runnable {
    private final InetAddress hostAddress;
    private final int port;
    private Selector selector;
    // Reused for every read; holds one chunk of the server's reply.
    private final ByteBuffer readBuffer = ByteBuffer.allocate(8192);
    // Request payload; its position tracks how much has been written so far.
    private final ByteBuffer outBuffer = ByteBuffer.wrap("nice to meet you"
            .getBytes(StandardCharsets.UTF_8));

    /**
     * Creates the client and immediately opens the selector and starts connecting.
     *
     * @throws IOException if the selector or channel cannot be set up
     */
    public NioClient(InetAddress hostAddress, int port) throws IOException {
        this.hostAddress = hostAddress;
        this.port = port;
        initSelector();
    }

    public static void main(String[] args) {
        try {
            NioClient client = new NioClient(
                    InetAddress.getByName("localhost"), 9090);
            new Thread(client).start();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Runs the selector loop until no live registration remains, then closes
     * the selector. An I/O error on a key closes only that connection. A
     * failure of {@code select()} itself stops the loop instead of spinning
     * forever.
     */
    @Override
    public void run() {
        try {
            while (hasValidKeys()) {
                selector.select();
                Iterator<SelectionKey> selectedKeys = selector.selectedKeys().iterator();
                while (selectedKeys.hasNext()) {
                    SelectionKey key = selectedKeys.next();
                    selectedKeys.remove();
                    if (!key.isValid()) {
                        continue;
                    }
                    try {
                        if (key.isConnectable()) {
                            finishConnection(key);
                        } else if (key.isReadable()) {
                            read(key);
                        } else if (key.isWritable()) {
                            write(key);
                        }
                    } catch (IOException e) {
                        // Tear down this connection so its key cannot fire again forever.
                        e.printStackTrace();
                        closeChannel(key);
                    }
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            try {
                selector.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /** Returns true while at least one registered key is still valid (not cancelled, channel open). */
    private boolean hasValidKeys() {
        for (SelectionKey key : selector.keys()) {
            if (key.isValid()) {
                return true;
            }
        }
        return false;
    }

    /** Cancels {@code key} and closes its channel; shared by every teardown path. */
    private void closeChannel(SelectionKey key) {
        key.cancel();
        try {
            key.channel().close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Opens the selector and a non-blocking channel and starts connecting.
     * Setup resources are released if any step fails.
     */
    private void initSelector() throws IOException {
        selector = SelectorProvider.provider().openSelector();
        SocketChannel socketChannel = SocketChannel.open();
        try {
            socketChannel.configureBlocking(false);
            // A non-blocking connect may complete immediately (typical for loopback).
            // In that case no connection is pending and OP_CONNECT may never be
            // reported, so go straight to writing.
            boolean connected = socketChannel
                    .connect(new InetSocketAddress(this.hostAddress, this.port));
            socketChannel.register(selector,
                    connected ? SelectionKey.OP_WRITE : SelectionKey.OP_CONNECT);
        } catch (IOException e) {
            socketChannel.close();
            selector.close();
            throw e;
        }
    }

    /**
     * Completes a pending connect and switches interest to OP_WRITE.
     * A failed connect closes the channel; a still-pending one keeps waiting.
     */
    private void finishConnection(SelectionKey key) throws IOException {
        SocketChannel socketChannel = (SocketChannel) key.channel();
        try {
            // Throws if the connection attempt failed; returns false if it is still pending.
            if (!socketChannel.finishConnect()) {
                return;
            }
        } catch (IOException e) {
            e.printStackTrace();
            closeChannel(key);
            return;
        }
        key.interestOps(SelectionKey.OP_WRITE);
    }

    /**
     * Handles a readable key: reads one chunk and hands it to
     * {@link #handleResponse}. End-of-stream or a read error closes the connection.
     */
    private void read(SelectionKey key) throws IOException {
        SocketChannel socketChannel = (SocketChannel) key.channel();
        readBuffer.clear();
        int numRead;
        try {
            numRead = socketChannel.read(readBuffer);
        } catch (IOException e) {
            closeChannel(key);
            return;
        }
        if (numRead == -1) {
            // -1 signals end-of-stream: the server closed its side.
            System.out.println("close connection");
            closeChannel(key);
            return;
        }
        if (numRead == 0) {
            // Nothing available yet; wait for the next readiness event.
            return;
        }
        handleResponse(socketChannel, readBuffer.array(), numRead);
    }

    /**
     * Prints the first {@code numRead} bytes of {@code data} as UTF-8 text and
     * then closes the connection.
     */
    private void handleResponse(SocketChannel socketChannel, byte[] data,
            int numRead) throws IOException {
        System.out.println(new String(data, 0, numRead, StandardCharsets.UTF_8));
        // Cancel before closing so the registration is still found by keyFor().
        SelectionKey key = socketChannel.keyFor(selector);
        if (key != null) {
            key.cancel();
        }
        socketChannel.close();
    }

    /**
     * Handles a writable key: writes as much of the request as the socket
     * accepts. Once the request is fully sent, switches interest to OP_READ.
     */
    private void write(SelectionKey key) throws IOException {
        SocketChannel socketChannel = (SocketChannel) key.channel();
        socketChannel.write(outBuffer);
        if (outBuffer.hasRemaining()) {
            // Partial write: keep OP_WRITE and continue on the next event.
            return;
        }
        key.interestOps(SelectionKey.OP_READ);
    }
}
TServerTransport => TNonblockingServerTransport => TNonblockingServerSocket（继承链 / inheritance chain）
/**
 * Base class for server-side transports that hand out one {@link TTransport}
 * per accepted client connection.
 */
public abstract class TServerTransport {

    /** Prepares the transport to accept connections (implementation-defined). */
    public abstract void listen() throws TTransportException;

    /**
     * Accepts the next client connection by delegating to {@link #acceptImpl()}.
     *
     * @return the accepted transport; never {@code null}
     * @throws TTransportException if {@code acceptImpl()} fails or returns {@code null}.
     *         A non-blocking {@code acceptImpl()} that finds no pending connection
     *         returns {@code null}, which is therefore reported here as an exception.
     */
    public final TTransport accept() throws TTransportException {
        final TTransport accepted = acceptImpl();
        if (accepted != null) {
            return accepted;
        }
        throw new TTransportException("accept() may not return NULL");
    }

    /** Releases the listening resources. */
    public abstract void close();

    /** Subclass hook that performs the actual accept; may return {@code null}. */
    protected abstract TTransport acceptImpl() throws TTransportException;

    /**
     * Optional hook asking the transport to break out of any accept() or
     * listen() call it is currently blocked in. Implementations MUST be
     * thread-safe, because this may be invoked from a thread other than the
     * one using the rest of the TServerTransport methods. Does nothing by default.
     */
    public void interrupt() {
        // optional; no-op unless overridden
    }
}
/**
 * Server transport whose listening channel can be multiplexed by a java.nio
 * {@link Selector} instead of blocking in accept().
 */
public abstract class TNonblockingServerTransport extends TServerTransport {

    /**
     * Registers this transport's listening channel with {@code selector}
     * (the socket implementation registers interest in OP_ACCEPT).
     */
    public abstract void registerSelector(Selector selector);
}
/**
 * Non-blocking server transport backed by a {@link ServerSocketChannel}.
 * Accepted connections are wrapped in {@link TNonblockingSocket}s, each using
 * the configured client timeout.
 */
public class TNonblockingServerSocket extends TNonblockingServerTransport {
    // Log under this concrete class's name so messages are attributed correctly.
    private static final Logger LOGGER = LoggerFactory.getLogger(TNonblockingServerSocket.class.getName());

    /**
     * This channel is where all the nonblocking magic happens.
     */
    private ServerSocketChannel serverSocketChannel = null;

    /**
     * Underlying ServerSocket object (obtained from serverSocketChannel.socket()).
     */
    private ServerSocket serverSocket_ = null;

    /**
     * Timeout applied to each client socket returned from accept.
     */
    private int clientTimeout_ = 0;

    /**
     * Creates a server socket listening on {@code port} on all local addresses.
     */
    public TNonblockingServerSocket(int port) throws TTransportException {
        this(port, 0);
    }

    /**
     * Creates a server socket listening on {@code port}; accepted clients get
     * {@code clientTimeout}.
     */
    public TNonblockingServerSocket(int port, int clientTimeout) throws TTransportException {
        this(new InetSocketAddress(port), clientTimeout);
    }

    public TNonblockingServerSocket(InetSocketAddress bindAddr) throws TTransportException {
        this(bindAddr, 0);
    }

    /**
     * Opens a non-blocking channel and binds it to {@code bindAddr}.
     *
     * @throws TTransportException if the channel cannot be opened or bound. The
     *         channel is closed before throwing and the cause is preserved.
     */
    public TNonblockingServerSocket(InetSocketAddress bindAddr, int clientTimeout) throws TTransportException {
        clientTimeout_ = clientTimeout;
        try {
            serverSocketChannel = ServerSocketChannel.open();
            serverSocketChannel.configureBlocking(false);

            // Make server socket
            serverSocket_ = serverSocketChannel.socket();
            // Prevent 2MSL delay problem on server restarts
            serverSocket_.setReuseAddress(true);
            // Bind to listening port
            serverSocket_.bind(bindAddr);
        } catch (IOException ioe) {
            serverSocket_ = null;
            // Don't leak the half-initialised channel (e.g. when the port is in use).
            closeChannelQuietly();
            throw new TTransportException("Could not create ServerSocket on address " + bindAddr.toString() + ".", ioe);
        }
    }

    /** Closes {@link #serverSocketChannel} if it was opened, ignoring secondary failures. */
    private void closeChannelQuietly() {
        if (serverSocketChannel == null) {
            return;
        }
        try {
            serverSocketChannel.close();
        } catch (IOException ignored) {
            // Already failing; the original exception is the one worth reporting.
        }
    }

    /** Sets SO_TIMEOUT to 0 (no timeout) on the underlying server socket. */
    public void listen() throws TTransportException {
        // Make sure not to block on accept
        if (serverSocket_ != null) {
            try {
                serverSocket_.setSoTimeout(0);
            } catch (SocketException sx) {
                LOGGER.error("Could not set socket timeout.", sx);
            }
        }
    }

    /**
     * Accepts one pending connection without blocking.
     *
     * @return the wrapped client socket, or {@code null} when no connection is pending
     * @throws TTransportException NOT_OPEN if the server socket is closed, or on I/O failure
     */
    protected TNonblockingSocket acceptImpl() throws TTransportException {
        if (serverSocket_ == null) {
            throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket.");
        }
        try {
            SocketChannel socketChannel = serverSocketChannel.accept();
            if (socketChannel == null) {
                return null;
            }

            TNonblockingSocket tsocket = new TNonblockingSocket(socketChannel);
            tsocket.setTimeout(clientTimeout_);
            return tsocket;
        } catch (IOException iox) {
            throw new TTransportException(iox);
        }
    }

    /** Registers the listening channel with {@code selector} for OP_ACCEPT events. */
    public void registerSelector(Selector selector) {
        try {
            // Register the server socket channel, indicating an interest in
            // accepting new connections
            serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);
        } catch (ClosedChannelException e) {
            // The channel has already been closed (e.g. via close()/interrupt()).
            // The signature cannot throw a checked exception, so at least make it visible.
            LOGGER.error("Could not register server socket channel: channel is closed.", e);
        }
    }

    /** Closes the underlying server socket (if still open) and forgets it. */
    public void close() {
        if (serverSocket_ != null) {
            try {
                serverSocket_.close();
            } catch (IOException iox) {
                LOGGER.warn("WARNING: Could not close server socket: " + iox.getMessage(), iox);
            }
            serverSocket_ = null;
        }
    }

    /**
     * Breaks out of accept by closing the server socket.
     * NOTE(review): thread-safety is questionable — serverSocket_ is not volatile
     * or synchronized — although the Java documentation suggests closing from
     * another thread is safe.
     */
    public void interrupt() {
        // The thread-safeness of this is dubious, but Java documentation suggests
        // that it is safe to do this from a different thread context
        close();
    }
}
TTransport => TFramedTransport（继承链 / inheritance chain）
/**
 * Transport that prefixes every message with its 4-byte frame size
 * (big-endian, see {@link #encodeFrameSize}).
 *
 * <p>Writes are buffered until {@link #flush()}, which sends the size prefix
 * followed by the buffered bytes. Reads pull one whole frame at a time from
 * the wrapped transport and serve it from an in-memory buffer.
 */
public class TFramedTransport extends TTransport {

    /** Default upper bound on an incoming frame's size, in bytes. */
    protected static final int DEFAULT_MAX_LENGTH = 16384000;

    /** Largest incoming frame size accepted by {@link #readFrame()}. */
    private final int maxLength_;

    /**
     * Underlying transport
     */
    private final TTransport transport_;

    /**
     * Buffer for output; drained by flush().
     */
    private final TByteArrayOutputStream writeBuffer_ = new TByteArrayOutputStream(
            1024);

    /**
     * Buffer for input; holds the frame currently being consumed.
     */
    private final TMemoryInputTransport readBuffer_ = new TMemoryInputTransport(
            new byte[0]);

    /** Scratch space for the 4-byte frame-size header (used by both read and write paths). */
    private final byte[] i32buf = new byte[4];

    /** Factory producing framed transports with a fixed maximum frame length. */
    public static class Factory extends TTransportFactory {
        private int maxLength_;

        public Factory() {
            maxLength_ = TFramedTransport.DEFAULT_MAX_LENGTH;
        }

        public Factory(int maxLength) {
            maxLength_ = maxLength;
        }

        @Override
        public TTransport getTransport(TTransport base) {
            return new TFramedTransport(base, maxLength_);
        }
    }

    /**
     * Wraps {@code transport}, rejecting incoming frames larger than {@code maxLength} bytes.
     */
    public TFramedTransport(TTransport transport, int maxLength) {
        transport_ = transport;
        maxLength_ = maxLength;
    }

    /** Wraps {@code transport} with {@link #DEFAULT_MAX_LENGTH}. */
    public TFramedTransport(TTransport transport) {
        this(transport, TFramedTransport.DEFAULT_MAX_LENGTH);
    }

    public void open() throws TTransportException {
        transport_.open();
    }

    public boolean isOpen() {
        return transport_.isOpen();
    }

    public void close() {
        transport_.close();
    }

    /**
     * Reads up to {@code len} bytes from the current frame. When the current
     * frame is exhausted, reads the next frame from the wrapped transport first.
     */
    public int read(byte[] buf, int off, int len) throws TTransportException {
        // A zero-length request must not pull a new frame: readFrame() resets
        // readBuffer_ and would silently drop whatever remains of the current frame.
        if (len == 0) {
            return 0;
        }
        int got = readBuffer_.read(buf, off, len);
        if (got > 0) {
            return got;
        }

        // Read another frame of data
        readFrame();

        return readBuffer_.read(buf, off, len);
    }

    @Override
    public byte[] getBuffer() {
        return readBuffer_.getBuffer();
    }

    @Override
    public int getBufferPosition() {
        return readBuffer_.getBufferPosition();
    }

    @Override
    public int getBytesRemainingInBuffer() {
        return readBuffer_.getBytesRemainingInBuffer();
    }

    @Override
    public void consumeBuffer(int len) {
        readBuffer_.consumeBuffer(len);
    }

    /**
     * Reads the 4-byte size header and then the whole frame body into readBuffer_.
     *
     * @throws TTransportException if the size is negative or exceeds maxLength_
     *         (the wrapped transport is closed first, because the stream can no
     *         longer be re-synchronised), or if the wrapped transport fails
     */
    private void readFrame() throws TTransportException {
        transport_.readAll(i32buf, 0, 4);
        int size = decodeFrameSize(i32buf);

        if (size < 0) {
            close();
            throw new TTransportException("Read a negative frame size (" + size
                    + ")!");
        }

        if (size > maxLength_) {
            close();
            throw new TTransportException("Frame size (" + size
                    + ") larger than max length (" + maxLength_ + ")!");
        }

        byte[] buff = new byte[size];
        transport_.readAll(buff, 0, size);
        readBuffer_.reset(buff);
    }

    /** Appends bytes to the output buffer; nothing is sent until {@link #flush()}. */
    public void write(byte[] buf, int off, int len) throws TTransportException {
        writeBuffer_.write(buf, off, len);
    }

    /**
     * Sends the buffered bytes as one frame (size header, then body) and
     * flushes the wrapped transport. The buffer is reset before sending.
     * NOTE(review): this relies on reset() only rewinding the length, not
     * clearing the backing array returned by get() — confirm in TByteArrayOutputStream.
     */
    @Override
    public void flush() throws TTransportException {
        byte[] buf = writeBuffer_.get();
        int len = writeBuffer_.len();
        writeBuffer_.reset();

        encodeFrameSize(len, i32buf);
        transport_.write(i32buf, 0, 4);
        transport_.write(buf, 0, len);
        transport_.flush();
    }

    /** Writes {@code frameSize} into {@code buf[0..3]}, most significant byte first. */
    public static final void encodeFrameSize(final int frameSize,
            final byte[] buf) {
        buf[0] = (byte) (0xff & (frameSize >> 24));
        buf[1] = (byte) (0xff & (frameSize >> 16));
        buf[2] = (byte) (0xff & (frameSize >> 8));
        buf[3] = (byte) (0xff & (frameSize));
    }

    /** Inverse of {@link #encodeFrameSize}: reads a big-endian int from {@code buf[0..3]}. */
    public static final int decodeFrameSize(final byte[] buf) {
        return ((buf[0] & 0xff) << 24) | ((buf[1] & 0xff) << 16)
                | ((buf[2] & 0xff) << 8) | ((buf[3] & 0xff));
    }

    /** Demo: round-trips a frame size through encode/decode and prints it. */
    public static void main(String[] args) {
        int number = 99999;
        byte[] buf = new byte[4];
        encodeFrameSize(number, buf);
        System.out.println(decodeFrameSize(buf));
    }
}