netty websocket使用

一.maven依赖

<!-- Spring Boot web starter. Optional: it can be omitted if the system does not use any web-related features. -->
<!-- NOTE(review): the ${...} version placeholders are assumed to be defined in the POM's <properties> section. -->
<dependency>
	<groupId>org.springframework.boot</groupId>
	<artifactId>spring-boot-starter-web</artifactId>
	<version>${springboot.version}</version>
</dependency>

<!-- Lombok: provides @Slf4j (the generated `log` field) used by the Java classes; compile-time only, hence scope "provided". -->
<dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
    <version>${lombok.version}</version>
    <scope>provided</scope>
</dependency>

<!-- fastjson2: parses the incoming JSON command messages (used by CommandInvoker). -->
<dependency>
    <groupId>com.alibaba.fastjson2</groupId>
    <artifactId>fastjson2</artifactId>
    <version>${fastjson2.version}</version>
</dependency>

<!-- Netty (all modules): server bootstrap, HTTP codec/aggregator and WebSocket protocol handlers. -->
<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-all</artifactId>
    <version>${netty.version}</version>
</dependency>
<!-- java-jwt: used for the token check before the WebSocket handshake; any other verification mechanism can be used instead. -->
<dependency>
    <groupId>com.auth0</groupId>
    <artifactId>java-jwt</artifactId>
    <version>${jwt.version}</version>
</dependency>

二.包结构

demo包结构

三.demo代码

1.基础架构层代码

(1)netty功能基础实现和接口
socket基础接口
package com.zzc.netty.infrastructure.netty;

import com.zzc.netty.infrastructure.netty.config.SocketConfig;

/**
 * Basic lifecycle contract for a Netty-based socket endpoint (server or client side).
 *
 * @param <C> the configuration type used to start the socket
 */
public interface Socket<C extends SocketConfig> {

    /**
     * Starts the socket using the configuration and channel handler that were set beforehand.
     *
     * @return {@code true} if the socket started successfully
     */
    boolean start();

    /**
     * Stores the given configuration and channel handler, then starts the socket.
     *
     * @param serverConfig            configuration to start with
     * @param webSocketChannelHandler callback that receives handshake, channel and message events
     * @return {@code true} if the socket started successfully
     */
    boolean start(C serverConfig, WebSocketChannelHandler webSocketChannelHandler);

    /** @return whether the socket is currently marked as started */
    boolean isStarted();

    /** Stops the socket and releases its resources. */
    void close();

    /** @return the current configuration; may be {@code null} if none has been set yet */
    C getConfig();

    /**
     * Replaces the configuration.
     *
     * @param config the new configuration
     * @return this socket, for call chaining
     */
    Socket setConfig(C config);

    /** @return {@code true} for a server-side socket, {@code false} for a client-side one */
    boolean isServer();
}
socket基础实现和初始化
package com.zzc.netty.infrastructure.netty;

import com.zzc.netty.infrastructure.netty.config.SocketConfig;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.handler.timeout.IdleStateHandler;
import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.TimeUnit;

/**
 * Common base for Netty sockets. It holds the configuration and the {@link WebSocketChannelHandler},
 * tracks the started flag and installs the idle-state handlers. Subclasses implement the actual
 * start/stop logic in {@link #doStart} and {@link #doClose}.
 *
 * @param <C> configuration type
 */
@Slf4j
public abstract class BaseSocket<C extends SocketConfig> implements Socket<C> {

    /** Receives channel events; must be set before {@link #start()} is called. */
    private WebSocketChannelHandler webSocketChannelHandler;

    private C config;

    /** Whether this socket plays the server role; fixed at construction time. */
    private final boolean server;

    private boolean started = false;

    /**
     * @param server {@code true} if this socket is server-side
     */
    public BaseSocket(boolean server) {
        this.server = server;
    }

    /**
     * Starts the socket with the previously set configuration and channel handler.
     *
     * @return the result of {@link #doStart}, which is also stored as the started flag
     * @throws RuntimeException if no configuration or handler has been set, or if {@code doStart}
     *                          throws (the original exception is attached as the cause)
     */
    @Override
    public boolean start() {
        C serverConfig = getConfig();
        if (serverConfig == null) {
            throw new RuntimeException("serverConfig is null.");
        }
        if (getWebSocketChannelHandler() == null) {
            throw new RuntimeException("OnChannelHandler is null.");
        }
        boolean result;
        try {
            result = doStart(serverConfig);
        } catch (Exception e) {
            // keep the cause, otherwise the real startup failure is lost
            throw new RuntimeException("start server error", e);
        }
        setStarted(result);
        return result;
    }

    /**
     * Stores the handler and configuration, then delegates to {@link #start()}.
     *
     * @throws RuntimeException if {@code serverConfig} is {@code null}
     */
    @Override
    public boolean start(C serverConfig, WebSocketChannelHandler webSocketChannelHandler) {
        this.webSocketChannelHandler = webSocketChannelHandler;
        setConfig(serverConfig);
        return start();
    }

    @Override
    public boolean isStarted() {
        return started;
    }

    /**
     * Stops the socket via {@link #doClose()} if it is started. Errors are logged, not thrown.
     * Afterwards the socket is marked as stopped, so repeated calls are no-ops.
     */
    @Override
    public void close() {
        if (!isStarted()) {
            return;
        }
        log.info("stop server");
        try {
            doClose();
        } catch (Exception e) {
            log.error("stop server error.", e);
        } finally {
            // reset the flag so a second close() does not shut the resources down again
            setStarted(false);
        }
    }

    @Override
    public C getConfig() {
        return config;
    }

    /** @return the handler that receives channel events, or {@code null} if not set yet */
    protected WebSocketChannelHandler getWebSocketChannelHandler() {
        return this.webSocketChannelHandler;
    }

    /**
     * @throws RuntimeException if {@code config} is {@code null}
     */
    @Override
    public Socket setConfig(C config) {
        if (config == null) {
            throw new RuntimeException("conf is null");
        }
        this.config = config;
        return this;
    }

    @Override
    public boolean isServer() {
        return server;
    }

    protected void setStarted(boolean started) {
        this.started = started;
    }

    /**
     * Appends an {@link IdleStateHandler} (timeouts in milliseconds) followed by a trigger that
     * forwards idle events to {@link WebSocketChannelHandler#onIdleStateEvent}.
     */
    protected void addNettyIdleHandler(ChannelPipeline pipeline) {
        long readIdleTimeout = getReadIdleTimeout();
        long writeIdleTimeout = getWriteIdleTimeout();
        long allIdleTimeout = getAllIdleTimeout();
        pipeline.addLast(new IdleStateHandler(readIdleTimeout, writeIdleTimeout, allIdleTimeout, TimeUnit.MILLISECONDS));
        pipeline.addLast(new SocketIdleStateTrigger());
    }

    /** Reader-idle timeout: the configured close timeout. */
    protected long getReadIdleTimeout() {
        C conf = getConfig();
        return conf.getCloseTimeout();
    }

    /**
     * Writer-idle timeout: a quarter of the close timeout, but at least 15 s, and never more than
     * half of the close timeout (the cap wins when the close timeout is below 30 s).
     */
    protected long getWriteIdleTimeout() {
        C conf = getConfig();
        long closeTimeout = conf.getCloseTimeout();
        return Math.min(Math.max(closeTimeout / 4, (15 * 1000)), (closeTimeout / 2));
    }

    /** All-idle timeout: the close timeout plus 500 ms, slightly longer than the reader-idle timeout. */
    protected long getAllIdleTimeout() {
        C conf = getConfig();
        long closeTimeout = conf.getCloseTimeout();
        return closeTimeout + 500;
    }

    /**
     * Performs the actual start.
     *
     * @return {@code true} if the socket is up
     */
    protected abstract boolean doStart(C serverConfig);

    /** Performs the actual shutdown. */
    protected abstract void doClose();

    /** Hook for subclasses to append protocol handlers to each channel pipeline; no-op by default. */
    protected void addNettyOtherHandler(ChannelPipeline pipeline) {
        // intentionally empty
    }

    /** Forwards {@link IdleStateEvent}s to the channel handler and passes every other event on. */
    @ChannelHandler.Sharable
    class SocketIdleStateTrigger extends ChannelInboundHandlerAdapter {

        @Override
        public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
            if (evt instanceof IdleStateEvent) {
                IdleState state = ((IdleStateEvent) evt).state();
                // the decision whether to disconnect is left to the application-level handler
                getWebSocketChannelHandler().onIdleStateEvent(ctx.channel(), state);
            } else {
                super.userEventTriggered(ctx, evt);
            }
        }
    }
}
消息或指令接收接口
package com.zzc.netty.infrastructure.netty;

import io.netty.channel.Channel;
import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.timeout.IdleState;

import java.util.Map;

/**
 * Receives messages/commands and lifecycle events for WebSocket channels.
 * It is invoked by the Netty pipeline handlers (the pre-handshake security handler, the frame
 * handler and the idle-state trigger).
 */
public interface WebSocketChannelHandler {

    /**
     * Called before the WebSocket handshake with the query parameters parsed from the request URI.
     *
     * @return {@code true} to accept the connection; on {@code false} the security handler closes the channel
     */
    boolean beforeHandshake(Channel channel, Map<String, Object> params);

    /**
     * Called once the WebSocket handshake has completed, with the parameters stored by the
     * pre-handshake check (may be {@code null} if none were stored).
     *
     * @return whether the application accepted the connection
     */
    boolean afterHandshake(Channel channel, Map<String, Object> params);

    /** Called when the channel becomes active, i.e. before the handshake has taken place. */
    void channelActive(Channel channel);

    /** Called when the channel becomes inactive (connection closed). */
    void channelInactive(Channel channel);

    /** Called for each inbound message reaching the frame handler (WebSocket frames after the handshake). */
    void channelRead(Channel channel, Object msg);

    /** Called when an exception reaches the frame handler. */
    void onException(Channel channel, Throwable throwable);

    /** Called for each reader/writer/all idle state detected by the idle-state handler. */
    void onIdleStateEvent(Channel channel, IdleState state);


}

(2)netty配置(config)
package com.zzc.netty.infrastructure.netty.config;
/**
 * Netty socket configuration contract.
 */
public interface SocketConfig {

    /** IP/host to bind to; may be {@code null} (the default in BaseSocketConfig). */
    public String getIp();

    public void setIp(String ip);

    /** Port to bind to. */
    public int getPort();

    public void setPort(int port);

    /** Connect/bind timeout; NOTE(review): presumably milliseconds (default 30 * 1000) — confirm. */
    public long getConnectTimeout();

    public void setConnectTimeout(long connectTimeout);

    /** Write timeout; NOTE(review): presumably milliseconds, not used by the visible code. */
    public long getWriteTimeout();

    public void setWriteTimeout(long writeTimeout);

    /**
     * Close timeout in milliseconds: the idle-state timeouts are derived from it and passed to
     * Netty's IdleStateHandler with TimeUnit.MILLISECONDS.
     */
    public long getCloseTimeout();

    public void setCloseTimeout(long closeTimeout);

}

package com.zzc.netty.infrastructure.netty.config;


/**
 * Base implementation of {@link SocketConfig}. Protocol-specific configurations extend this
 * class when they need additional settings.
 */
public class BaseSocketConfig implements SocketConfig {

    // Defaults are true constants: previously they were mutable per-instance fields.

    /** Default connect timeout: 30 * 1000 (presumably milliseconds). */
    static final long TIMEOUT_CONNECT = 30 * 1000L;

    /** Default write timeout: 30 * 1000 (presumably milliseconds). */
    static final long TIMEOUT_WRITE = 30 * 1000L;

    /** Default close timeout: 120 s in milliseconds; the idle timeouts are derived from it. */
    static final long TIMEOUT_CLOSE = 120 * 1000L;

    /** IP/host to bind to; {@code null} by default. */
    private String ip;

    private int port = 9696;

    private long connectTimeout = TIMEOUT_CONNECT;

    private long writeTimeout = TIMEOUT_WRITE;

    private long closeTimeout = TIMEOUT_CLOSE;

    @Override
    public String getIp() {
        return ip;
    }

    @Override
    public void setIp(String ip) {
        this.ip = ip;
    }

    @Override
    public int getPort() {
        return port;
    }

    @Override
    public void setPort(int port) {
        this.port = port;
    }

    @Override
    public long getConnectTimeout() {
        return connectTimeout;
    }

    @Override
    public void setConnectTimeout(long connectTimeout) {
        this.connectTimeout = connectTimeout;
    }

    @Override
    public long getWriteTimeout() {
        return writeTimeout;
    }

    @Override
    public void setWriteTimeout(long writeTimeout) {
        this.writeTimeout = writeTimeout;
    }

    @Override
    public long getCloseTimeout() {
        return closeTimeout;
    }

    @Override
    public void setCloseTimeout(long closeTimeout) {
        this.closeTimeout = closeTimeout;
    }
}
(3)连接接口和封装(conn)
channel基础操作接口封装
package com.zzc.netty.infrastructure.netty.conn;

/**
 * Basic operations on a single connection wrapping a Netty channel.
 * NOTE(review): the implementation (BaseConn) is not visible here; the per-method notes below
 * are inferred from names and usage and should be confirmed against it.
 */
public interface Conn {

    //boolean isServer();

    /** Presumably whether the underlying channel is still open. */
    boolean isOpen();

    /** Presumably enables or disables writing on this connection. */
    void setAllowWrite(boolean allowWrite);

    boolean isAllowWrite();

    /** Identifier of this connection; the application layer uses it as the session id. */
    String getConnId();

    /** Presumably releases the connection after a default delay. */
    void delayRelease();

    /** Presumably releases the connection after the given delay in milliseconds. */
    void delayRelease(int delayMilliSecond);

    /** Releases the connection; called when a channel closes, fails or idles out. */
    void release();

    /** Presumably closes gracefully, sending the given status code first — confirm. */
    void releaseGracefully(Object statusCode);

    void releaseGracefully();

    /** Writes an object to the channel and reports whether it succeeded (presumably waits for the write). */
    boolean writeObj(Object msg);


    /** Writes an object without waiting for the result. */
    void writeObjAsyn(Object msg);


    /** Writes an object asynchronously and notifies {@code listener} of the outcome. */
    void writeObjAsyn(Object msg, ConnFutureListener listener);


    /** Presumably sends a WebSocket ping frame. */
    void writePing();

    /** Presumably sends a WebSocket pong frame. */
    void writePong();

}

/**
 * Listens for the result of an asynchronous channel operation
 * (used with {@code Conn.writeObjAsyn(Object, ConnFutureListener)}).
 */
public interface ConnFutureListener {

    /** Presumably invoked when the operation succeeded. */
    void onSuccess(Conn conn);

    /** Presumably invoked when the operation was cancelled. */
    void onCancel();

    /** Presumably invoked when the operation failed. */
    void onFailed();

}

websocket能力接口的封装
package com.zzc.netty.infrastructure.netty.conn;

import com.zzc.netty.domain.protocol.Response;

/**
 * WebSocket-specific capabilities on top of {@link Conn}: sending protocol {@link Response}s.
 */
public interface WebSocketConn extends Conn {

    /**
     * Sends a response (in WebSocketConnImpl: its {@code toString()} as a text frame).
     *
     * @return the outcome reported by the underlying write
     */
    boolean writeResp(Response response);

    /** Sends a response without waiting for the write result. */
    void writeRespAsyn(Response response);

}

package com.zzc.netty.infrastructure.netty.conn;

import com.zzc.netty.domain.protocol.Response;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;

/**
 * {@link WebSocketConn} backed by a Netty channel. Every {@link Response} is sent as a
 * WebSocket text frame containing {@code response.toString()}.
 */
@Slf4j
public class WebSocketConnImpl extends BaseConn implements WebSocketConn {

    /**
     * @param channel the Netty channel this connection wraps
     */
    public WebSocketConnImpl(Channel channel) {
        super(channel);
    }

    /**
     * Sends the response through {@code writeObj}.
     *
     * @return whatever {@code writeObj} reports for the frame
     */
    @Override
    public boolean writeResp(Response response) {
        return writeObj(asTextFrame(response));
    }

    /** Sends the response through {@code writeObjAsyn}, without waiting for the outcome. */
    @Override
    public void writeRespAsyn(Response response) {
        writeObjAsyn(asTextFrame(response));
    }

    /** Wraps the string form of a response into a WebSocket text frame. */
    private TextWebSocketFrame asTextFrame(Response response) {
        String payload = response.toString();
        return new TextWebSocketFrame(payload);
    }
}

(4)handler实现(handler)
安全校验handler,在握手之前执行校验
package com.zzc.netty.infrastructure.netty.handler;

import com.zzc.netty.infracore.common.utils.HttpxUtils;
import com.zzc.netty.infrastructure.netty.WebSocketChannelHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.util.AttributeKey;
import lombok.extern.slf4j.Slf4j;

import java.util.Map;

@Slf4j
@ChannelHandler.Sharable
public class SecurityServerHandler extends ChannelInboundHandlerAdapter {

    public static final AttributeKey<Map<String, Object>> SECURITY_CHECK_ATTRIBUTE_KEY =
            AttributeKey.valueOf("SECURITY_CHECK_ATTRIBUTE_KEY");

    private WebSocketChannelHandler webSocketChannelHandler;

    public SecurityServerHandler(WebSocketChannelHandler webSocketChannelHandler) {
        this.webSocketChannelHandler = webSocketChannelHandler;
    }

    /**
     * 经过测试,在 ws 的 uri 后面不能传递参数,不然在 netty 实现 websocket 协议握手的时候会出现断开连接的情况。
     * 针对这种情况在 websocketHandler 之前做了一层 地址过滤,然后重写
     * request 的 uri,并传入下一个管道中,基本上解决了这个问题。
     * TODO 其他方式,就是重写握手的流程
     * @param ctx
     * @param msg
     * @throws Exception
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpRequest) {
            //签名校验
            FullHttpRequest request = (FullHttpRequest) msg;
            Map<String, Object> params = HttpxUtils.urlQueryParams(request.uri());//解析uri中的参数
            boolean valid = webSocketChannelHandler.beforeHandshake(ctx.channel(), params);//握手前校验链接是否通过
            log.info("channelRead valid:{}", valid);
            if (valid) {//校验通过则使用netty的事件发布
                request.setUri("/ws");//需要覆盖uri的参数,否则后续的握手包处有问题
                ctx.channel().attr(SECURITY_CHECK_ATTRIBUTE_KEY).set(params);
                ctx.fireUserEventTriggered(params);
                ctx.pipeline().remove(this);
            } else {
                ctx.close();
                return;
            }
            super.channelRead(ctx, msg);
        }

    }

}
websocket消息处理
package com.zzc.netty.infrastructure.netty.handler;

import com.zzc.netty.infrastructure.netty.WebSocketChannelHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.AttributeKey;
import lombok.extern.slf4j.Slf4j;

import java.util.Map;

/**
 * Last handler in the WebSocket pipeline. It forwards channel lifecycle events, exceptions,
 * handshake completion and inbound frames to the {@link WebSocketChannelHandler}.
 */
@Slf4j
@ChannelHandler.Sharable
public class WebSocketFrameHandler extends SimpleChannelInboundHandler<Object> {

    private final WebSocketChannelHandler webSocketChannelHandler;

    public WebSocketFrameHandler(WebSocketChannelHandler webSocketChannelHandler) {
        this.webSocketChannelHandler = webSocketChannelHandler;
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        webSocketChannelHandler.channelActive(ctx.channel());
        super.channelActive(ctx);
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        webSocketChannelHandler.channelInactive(ctx.channel());
        super.channelInactive(ctx);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        webSocketChannelHandler.onException(ctx.channel(), cause);
        super.exceptionCaught(ctx, cause);
    }

    /**
     * On handshake completion, hands the parameters saved by {@link SecurityServerHandler} to
     * {@link WebSocketChannelHandler#afterHandshake}. If the application rejects the connection,
     * the channel is closed; otherwise it would stay open without ever being registered.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
            Map<String, Object> params = ctx.channel().attr(SecurityServerHandler.SECURITY_CHECK_ATTRIBUTE_KEY).get();
            // log only the parameter names: the values include the auth token
            log.info("userEventTriggered paramKeys:{}", params == null ? null : params.keySet());
            boolean accepted = webSocketChannelHandler.afterHandshake(ctx.channel(), params);
            if (!accepted) {
                log.warn("afterHandshake rejected the connection, closing channel:{}", ctx.channel().id());
                ctx.close();
            }
        }
        super.userEventTriggered(ctx, evt);
    }

    /** Forwards each inbound frame; SimpleChannelInboundHandler releases it afterwards. */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
        webSocketChannelHandler.channelRead(ctx.channel(), msg);
    }

}
(5)服务实现(server)
server配置
package com.zzc.netty.infrastructure.netty.server;


import com.zzc.netty.infrastructure.netty.config.BaseSocketConfig;

/**
 * Server-side socket configuration: the base settings plus the boss/worker event-loop
 * thread counts. (Ideally this would be an interface with a separate implementation,
 * to make extension easier.)
 */
public class ServerSocketConfig extends BaseSocketConfig {

    /** Boss (accepting) thread count used by the no-arg constructor. */
    private static final int DEFAULT_BOSS_THREADS = 2;

    private final int bossThreads;

    private final int workThreads;

    /** Uses 2 boss threads and twice the number of available processors as worker threads. */
    public ServerSocketConfig() {
        this(DEFAULT_BOSS_THREADS, Runtime.getRuntime().availableProcessors() * 2);
    }

    /**
     * @param bossThreads thread count of the boss event-loop group
     * @param workThreads thread count of the worker event-loop group
     */
    public ServerSocketConfig(int bossThreads, int workThreads) {
        this.bossThreads = bossThreads;
        this.workThreads = workThreads;
    }

    public int getBossThreads() {
        return bossThreads;
    }

    public int getWorkThreads() {
        return workThreads;
    }

}

socket功能实现
package com.zzc.netty.infrastructure.netty.server;


import com.zzc.netty.infrastructure.netty.Socket;

/**
 * Server-side {@link Socket}, configured with a {@link ServerSocketConfig}.
 */
public interface ServerSocket extends Socket<ServerSocketConfig> {

}

package com.zzc.netty.infrastructure.netty.server;

import com.zzc.netty.infracore.common.utils.HttpxUtils;
import com.zzc.netty.infracore.common.utils.ThreadPoolUtils;
import com.zzc.netty.infrastructure.netty.BaseSocket;
import com.zzc.netty.infrastructure.netty.handler.SecurityServerHandler;
import com.zzc.netty.infrastructure.netty.handler.WebSocketFrameHandler;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.AdaptiveRecvByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.util.AttributeKey;
import lombok.extern.slf4j.Slf4j;

import java.util.Map;
import java.util.concurrent.TimeUnit;


@Slf4j
public class ServerSocketImpl extends BaseSocket<ServerSocketConfig> implements ServerSocket {

    private final static String THREAD_PREFIX_BOSS = "websocketBossServer";

    private final static String THREAD_PREFIX_WORK = "websocketWorkServer";

    ServerBootstrap bootstrap = null;

    EventLoopGroup bossGroup = null;

    EventLoopGroup workGroup = null;

    public ServerSocketImpl() {
        super(true);
    }

    @Override
    protected boolean doStart(ServerSocketConfig serverConfig) {
        long currentTimeMillis = System.currentTimeMillis();
        boolean listenResuset = false;
        try {
            bootstrap = new ServerBootstrap();
            bossGroup = new NioEventLoopGroup(serverConfig.getBossThreads(), ThreadPoolUtils.newThreadFactory(THREAD_PREFIX_BOSS));
            workGroup = new NioEventLoopGroup(serverConfig.getWorkThreads(), ThreadPoolUtils.newThreadFactory(THREAD_PREFIX_WORK));
            bootstrap.group(bossGroup, workGroup)
                    .channel(NioServerSocketChannel.class)
                    .childOption(ChannelOption.TCP_NODELAY, true)
                    .childOption(ChannelOption.SO_REUSEADDR, true)
                    .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
                    .childOption(ChannelOption.RCVBUF_ALLOCATOR, new AdaptiveRecvByteBufAllocator(64, 1024, 65535))
                    .childHandler(new ChannelInitializer<NioSocketChannel>() {
                        @Override
                        protected void initChannel(NioSocketChannel channel) throws Exception {
                            ChannelPipeline pipeline = channel.pipeline();
                            addNettyIdleHandler(pipeline);
                            addNettyOtherHandler(pipeline);
                        }
                    });
            listenResuset = bootstrap.bind(serverConfig.getPort()).await(serverConfig.getConnectTimeout(), TimeUnit.MICROSECONDS);
        } catch (Exception e) {
            log.error("listen server timeout.", e);
            return false;
        } finally {
            if (listenResuset) {
                log.info("listen server, result:{}, spendTime:{}", listenResuset, (System.currentTimeMillis() - currentTimeMillis));
            } else {
                log.error("listen server error, result:{}, spendTime:{}", listenResuset, (System.currentTimeMillis() - currentTimeMillis));
            }
        }
        return listenResuset;
    }

    @Override
    protected void doClose() {
        if (workGroup != null) {
            workGroup.shutdownGracefully();
        }
        if (bossGroup != null) {
            bossGroup.shutdownGracefully();
        }
    }

    @Override
    protected void addNettyOtherHandler(ChannelPipeline pipeline) {
        pipeline.addLast("http-codec", new HttpServerCodec());
        pipeline.addLast("aggregator", new HttpObjectAggregator(65536));
        pipeline.addLast("http-chunked", new ChunkedWriteHandler());
        pipeline.addLast("security-handler", new SecurityServerHandler(getWebSocketChannelHandler()));
        pipeline.addLast("websocket-compression", new WebSocketServerCompressionHandler());//websocket数据压缩
        pipeline.addLast("handler", new WebSocketServerProtocolHandler("/ws", null, true, 1024 * 1024, true));
        pipeline.addLast("websocket-handler", new WebSocketFrameHandler(getWebSocketChannelHandler()));

    }
}

2.应用层实现(命令模式)

(1)service实现
消息(指令)接收者抽象实现(Receiver接收者角色)
package com.zzc.netty.application.service;

import com.zzc.netty.domain.Constant;
import com.zzc.netty.infracore.common.utils.JWTUtils;
import com.zzc.netty.infracore.common.utils.StrUtils;
import com.zzc.netty.infrastructure.netty.WebSocketChannelHandler;
import com.zzc.netty.infrastructure.netty.conn.WebSocketConn;
import com.zzc.netty.infrastructure.netty.conn.WebSocketConnImpl;
import com.zzc.netty.domain.enums.DisconReason;
import io.netty.channel.Channel;
import io.netty.handler.timeout.IdleState;
import lombok.extern.slf4j.Slf4j;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Slf4j
public abstract class AbstractWebSocketChannelHandler implements WebSocketChannelHandler {

    private static Map<String, WebSocketConn> children = new ConcurrentHashMap<>();

    @Override
    public boolean beforeHandshake(Channel channel, Map<String, Object> params) {
        if (params == null) {
            return false;
        }
        String appId = (String) params.get(Constant.KEY_APP_ID);
        String userId = (String) params.get(Constant.KEY_USERID);
        String username = (String) params.get(Constant.KEY_USERNAME);
        String platform = (String) params.get(Constant.KEY_PLATFORM);
        String token = (String) params.get(Constant.KEY_TOKEN);
        String ts = (String) params.get(Constant.KEY_TS);
        if (StrUtils.isBlank(appId, userId, username, platform, token, ts)) {
            log.info("beforeHandshake param exist null. appId:{}, userId:{}, username:{}, platform:{}, ts:{}", appId, userId, username, platform, ts);
            return false;
        }
        boolean verify = JWTUtils.verify(token);//TODO
        return verify;
    }

    @Override
    public boolean afterHandshake(Channel channel, Map<String, Object> params) {
        log.info("afterHandshake params:{}", params);
        String connId = getConnId(channel);
        WebSocketConn conn = new WebSocketConnImpl(channel);
        boolean connected = connected(conn, params);
        //
        if (connected) {
            addConn(connId, conn);
        }
        return connected;
    }

    @Override
    public void channelActive(Channel channel) {
        log.info("channelActive connId:{}", getConnId(channel));
    }

    @Override
    public void channelInactive(Channel channel) {
        String connId = getConnId(channel);
        WebSocketConn conn = getConn(connId);
        if (conn != null) {
            log.info("remove conn. connId:{}", connId);
            try {
                disconnected(conn, DisconReason.NORMAL);//TODO
            } finally {
                ensureRelease(conn, channel);
            }
        }
    }

    @Override
    public void channelRead(Channel channel, Object msg) {
        String connId = getConnId(channel);
        WebSocketConn conn = getConn(connId);
        if (conn == null) {
            log.error("channelRead error, local cache is null. connId:{}, msg:{}", connId, msg);
            return;
        }
        receiver(conn, msg);
    }

    @Override
    public void onException(Channel channel, Throwable throwable) {
        String connId = getConnId(channel);
        WebSocketConn conn = getConn(connId);
        log.info("onException connId:{}", connId);
        try {
            disconnected(conn, DisconReason.EXCEPTION);
        } finally {
            ensureRelease(conn, channel);
        }
    }

    @Override
    public void onIdleStateEvent(Channel channel, IdleState state) {
        if (state == IdleState.WRITER_IDLE) {

        } else if (state == IdleState.READER_IDLE) {

        } else if (state == IdleState.ALL_IDLE) {
            // 太长时间无收发消息,一般要做断开连接
            String connId = getConnId(channel);
            WebSocketConn conn = getConn(connId);
            log.info("onIdleStateEvent connId:{}", connId);
            try {
                disconnected(conn, DisconReason.IDLE);
            } finally {
                ensureRelease(conn, channel);
            }
        }

    }

    private String getConnId(Channel channel) {
        if (channel != null) {
            return channel.id().asLongText();
        }
        return null;
    }


    private WebSocketConn getConn(String connId) {
        return children.get(connId);
    }

    private void addConn(String connId, WebSocketConn conn) {
        children.put(connId, conn);
    }

    private void removeConn(String connId) {
        children.remove(connId);
    }

    private void ensureRelease(WebSocketConn conn, Channel channel) {
        if (conn == null) {
            channel.close();
            log.warn("ensureRelease conn is null.");
            return;
        }
        try {
            conn.release();
        } finally {
            removeConn(conn.getConnId());
        }
    }

    /**
     * 握手成功后,业务层处理,比如保存用户会话到redis
     * @param conn
     * @param params
     */
    protected abstract boolean connected(WebSocketConn conn, Map<String, Object> params);


    /**
     * ws链接断开,内部清除完本地缓存之后,业务层处理
     */
    protected abstract void disconnected(WebSocketConn conn, DisconReason reason);


    /**
     * 链接正常,收到客户端的消息
     * @param conn
     * @param msg
     */
    protected abstract void receiver(WebSocketConn conn, Object msg);
}

定义指令调用者(Invoker调用者角色)
package com.zzc.netty.application.service;

import com.alibaba.fastjson2.JSON;
import com.zzc.netty.application.CommandFactory;
import com.zzc.netty.domain.command.CommandContext;
import com.zzc.netty.application.CommandHandler;
import com.zzc.netty.infracore.api.CommonCode;
import com.zzc.netty.domain.protocol.Response;
import com.zzc.netty.infracore.common.utils.ThreadPoolUtils;
import com.zzc.netty.infrastructure.netty.conn.WebSocketConn;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * Command invoker (Invoker role). It parses each incoming JSON command on a worker pool and
 * dispatches it to the {@link CommandHandler} registered for its command name.
 */
@Slf4j
@Service
public class CommandInvoker {

    private static final ThreadPoolExecutor executor = ThreadPoolUtils.newThreadPoolExecutorDirectAndAsy(
            "command-handler",
            2 * Runtime.getRuntime().availableProcessors(),
            4 * Runtime.getRuntime().availableProcessors(),
            120,
            TimeUnit.SECONDS,
            30);

    /**
     * Asynchronously parses and executes a command received on {@code conn}. Any failure is logged
     * and answered with a {@code SYSTEM_ERROR} response.
     *
     * @param conn    the connection the command came from (also used for the reply)
     * @param jsonStr the raw JSON text of the command
     */
    public void action(WebSocketConn conn, String jsonStr) {
        // execute() rather than submit(): an Error escaping the task is not swallowed by a Future
        executor.execute(() -> dispatch(conn, jsonStr));
    }

    /** Parses the command, looks up its handler and runs it; runs on the worker pool. */
    private void dispatch(WebSocketConn conn, String jsonStr) {
        try {
            CommandContext ctx = JSON.parseObject(jsonStr, CommandContext.class);
            if (ctx == null) {
                log.warn("receiver cmd is empty. msg:{}", jsonStr);
                conn.writeRespAsyn(Response.error(CommonCode.SYSTEM_ERROR));
                return;
            }
            ctx.setResultCode(CommonCode.SUCCESS);
            log.info("action ctx:{}, resultCode:{}", JSON.toJSONString(ctx), JSON.toJSONString(ctx.getResultCode()));

            ctx.setWebSocketConn(conn);
            String command = ctx.getCommand();
            CommandHandler handler = CommandFactory.getHandler(command);
            if (handler == null) {
                log.warn("receiver unknown cmd. command:{}", command);
                conn.writeRespAsyn(Response.error(CommonCode.SYSTEM_ERROR));
                return;
            }
            handler.execute(ctx);
        } catch (Exception e) {
            conn.writeRespAsyn(Response.error(CommonCode.SYSTEM_ERROR));
            log.error("receiver cmd error.", e);
        }
    }

}

命令抽象实现(命令角色)
package com.zzc.netty.application;

import com.zzc.netty.domain.command.CommandContext;
import com.zzc.netty.infracore.api.CommonCode;
import com.zzc.netty.infracore.exception.DomainException;
import lombok.extern.slf4j.Slf4j;

/**
 * Abstract command (Command role). A command runs through the fixed stages
 * ack, before, process and after, and then answers the client.
 *
 * @param <T> payload type carried by the {@link CommandContext}
 */
@Slf4j
public abstract class CommandHandler<T> {

    /**
     * Processing stages, in execution order. {@link #execute} no longer tracks them explicitly.
     * The enum is kept because it is package-visible.
     */
    enum State {

        ACK,

        BEFORE,

        PROCESS,

        AFTER,

        FINISH
    }

    /** Acknowledges the command; if it returns {@code false}, processing stops and no answer is sent. */
    protected abstract boolean ack(CommandContext<T> ctx);

    /** Pre-processing (e.g. validation); {@code false} skips the remaining stages. */
    protected abstract boolean beforeHandler(CommandContext<T> ctx);

    /** Main processing; {@code false} skips the remaining stages. */
    protected abstract boolean handler(CommandContext<T> ctx);

    /** Post-processing. */
    protected abstract boolean afterHandler(CommandContext<T> ctx);

    /** Sends the reply to the client. */
    protected abstract boolean answer(CommandContext<T> ctx);

    /**
     * Runs the stages in order; each stage runs only if every previous stage returned {@code true}.
     * If {@link #ack} fails nothing else happens. Otherwise {@link #answer} is called exactly once,
     * whether or not the later stages succeeded.
     */
    public void execute(CommandContext<T> ctx) {
        if (!ack(ctx)) {
            return;
        }
        // && short-circuits, so the first failing stage stops the pipeline
        boolean completed = beforeHandler(ctx) && handler(ctx) && afterHandler(ctx);
        log.debug("execute stages completed:{}", completed);
        answer(ctx);
    }

}

消息或指令实现(Receiver接收者角色和高层模块应用)
package com.zzc.netty.application.service;

import com.zzc.netty.application.dto.CloseWsCommand;
import com.zzc.netty.domain.Constant;
import com.zzc.netty.domain.protocol.ProtocolFactory;
import com.zzc.netty.domain.command.CommandEnums;
import com.zzc.netty.domain.session.UserSession;
import com.zzc.netty.domain.session.UserSessionService;
import com.zzc.netty.infracore.api.CommonCode;
import com.zzc.netty.domain.protocol.Response;
import com.zzc.netty.infrastructure.netty.conn.WebSocketConn;
import com.zzc.netty.domain.enums.DisconReason;
import com.zzc.netty.infrastructure.netty.server.ServerSocketConfig;
import com.zzc.netty.infrastructure.netty.server.ServerSocketImpl;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.util.Map;

@Slf4j
@Service
public class SignalServiceImpl extends AbstractWebSocketChannelHandler {

    @Autowired
    private UserSessionService userSessionService;

    @Autowired
    private CommandInvoker commandInvoker;

    @PostConstruct
    public void init() {
        ServerSocketImpl serverSocket = new ServerSocketImpl();
        serverSocket.start(new ServerSocketConfig(), this);
    }

    @Override
    protected boolean connected(WebSocketConn conn, Map<String, Object> params) {
        String sessionId = conn.getConnId();
        if (!conn.isOpen()) {
            throw new RuntimeException("connected conn is close. connId:" + sessionId);
        }
        if (params == null || params.isEmpty()) {
            conn.writeRespAsyn(Response.error(CommonCode.PARAM_NULL));
            return false;
        }

        String appId = (String) params.get(Constant.KEY_APP_ID);
        String userId = (String) params.get(Constant.KEY_USERID);
        String platform = (String) params.get(Constant.KEY_PLATFORM);
        String ts = (String) params.get(Constant.KEY_TS);
        String username = (String) params.get(Constant.KEY_USERNAME);

        //TODO 判断是否重连
        UserSession userSession = UserSession.Builder()
                .appId(appId)
                .sessionId(sessionId)
                .userId(userId)
                .platform(platform)
                .username(username)
                .ts(ts)
                .build();

        userSessionService.addUserSession(userSession);
        return conn.writeResp(ProtocolFactory.createResp(CommandEnums.OPENWS.getCommand(), sessionId, ts));
    }

    @Override
    protected void disconnected(WebSocketConn conn, DisconReason reason) {
        if (conn == null) {
            log.error("disconnected error, conn is null.");
            return;
        }
        log.info("disconnected connId:{}", conn.getConnId());
        UserSession userSession = userSessionService.removeUserSession(conn.getConnId());
        if (userSession != null) {
            CloseWsCommand closeWs = new CloseWsCommand();
            closeWs.setReason(reason.getReason());
            conn.writeRespAsyn(ProtocolFactory.createResp(closeWs, CommandEnums.CLOSEWS.getCommand(), userSession.getSessionId(), userSession.getTs()));
        } else {
            log.warn("usersession is null. connId:{}", conn.getConnId());
        }
    }

    @Override
    protected void receiver(WebSocketConn conn, Object msg) {
        if (msg instanceof TextWebSocketFrame) {
            String jsonData = ((TextWebSocketFrame) msg).text();
            log.info("receiver rev msg:{}", jsonData);
            commandInvoker.action(conn, jsonData);
        } else {
            log.warn("receiver other data. class:{}", msg.getClass());
        }
    }

}

指令工厂
package com.zzc.netty.application;

import com.zzc.netty.adapter.handler.AddRoomHandler;
import com.zzc.netty.domain.command.CommandEnums;
import com.zzc.netty.infracore.common.utils.SpringBeansUtil;

import java.util.HashMap;
import java.util.Map;

/**
 * Registry that maps command names to their {@link CommandHandler}.
 */
public class CommandFactory {

    private static final Map<String, CommandHandler> commandHandlers = createHandlers();

    /**
     * Builds the command-name to handler table. Every command implemented by extending
     * {@link CommandHandler} can be registered here. Handlers could also be obtained as beans from
     * Spring's IoC container, but not through {@code @Autowired}, because they all extend
     * {@code CommandHandler}.
     */
    private static Map<String, CommandHandler> createHandlers() {
        Map<String, CommandHandler> handlers = new HashMap<>();
        handlers.put(CommandEnums.ADDROOM.getCommand(), new AddRoomHandler());
        return handlers;
    }

    /**
     * @param command the command name
     * @return the registered handler, or {@code null} if none is registered for {@code command}
     */
    public static CommandHandler getHandler(String command) {
        return commandHandlers.get(command);
    }

}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值