文章目录
前言
在项目的开发过程中,有很多场景都需要使用消息提醒功能,例如上架提醒,维护提醒,留言提醒等等。在此需求背景下选型netty搭建websocket,来实现消息推送提醒。
一、Netty基本架构
上图是网上找的Netty架构概念图,大概描述了Netty的工作流程。
二、项目结构与具体实现
1.引入核心Netty依赖
代码如下:
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.51.Final</version>
</dependency>
其他JSON解析,Redis等等依赖按需引入。
2.核心代码实现
启动类
/**
* @author hizoo
*/
public class NioWebSocketServer {
private final Logger logger = LoggerFactory.getLogger(this.getClass());
public void init() {
logger.info("正在启动websocket服务器");
NioEventLoopGroup boss = new NioEventLoopGroup();
NioEventLoopGroup work = new NioEventLoopGroup();
try {
ServerBootstrap bootstrap = new ServerBootstrap();
bootstrap.group(boss, work);
bootstrap.channel(NioServerSocketChannel.class);
//自定义处理器
bootstrap.childHandler(new NioWebSocketChannelInitializer());
Channel channel = bootstrap.bind(8083).sync().channel();
logger.info("webSocket服务器启动成功:" + channel);
channel.closeFuture().sync();
} catch (InterruptedException e) {
e.printStackTrace();
logger.info("运行出错:" + e);
} finally {
boss.shutdownGracefully();
work.shutdownGracefully();
logger.info("websocket服务器已关闭");
}
}
public static void main(String[] args) {
new NioWebSocketServer().init();
}
}
netty搭建服务器基本流程,绑定主线程组和工作线程组,这部分对应架构图中的事件循环组,只有服务器才需要绑定端口,客户端是绑定一个地址,最重要的是ChannelInitializer配置channel(数据通道)参数,ChannelInitializer以异步的方式启动,最后是结束关闭两个线程组。
channel初始化
/**
* @author hizoo
*/
public class NioWebSocketChannelInitializer extends ChannelInitializer<SocketChannel> {
@SneakyThrows
@Override
protected void initChannel(SocketChannel ch) {
//设置log监听器
ch.pipeline().addLast("logging",new LoggingHandler("INFO"));
//设置解码器
ch.pipeline().addLast("http-codec",new HttpServerCodec());
//聚合器,使用websocket会用到
ch.pipeline().addLast("aggregator",new HttpObjectAggregator(65536));
//用于大数据的分区传输
ch.pipeline().addLast("http-chunked",new ChunkedWriteHandler());
//自定义的业务handler
ch.pipeline().addLast("handler",new NioWebSocketHandler());
}
}
自定义的业务handler
/**
* @author hizoo
*/
public class NioWebSocketHandler extends SimpleChannelInboundHandler<Object> {
private final Logger logger = LoggerFactory.getLogger(this.getClass());
private WebSocketServerHandshaker handshaker;
@Override
protected void channelRead0(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof FullHttpRequest) {
//以http请求形式接入,但是走的是websocket
handleHttpRequest(ctx, (FullHttpRequest) msg);
} else if (msg instanceof WebSocketFrame) {
//处理websocket客户端的消息
handlerWebSocketFrame(ctx, (WebSocketFrame) msg);
}
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
//添加连接
ChannelSupervise.addChannel(ctx.channel());
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
//断开连接
ChannelSupervise.removeChannel(ctx.channel());
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
ctx.flush();
}
private void handlerWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame){
try {
// 判断是否关闭链路的指令
if (frame instanceof CloseWebSocketFrame) {
handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());
return;
}
// 判断是否ping消息
if (frame instanceof PingWebSocketFrame) {
ctx.channel().write(
new PongWebSocketFrame(frame.content().retain()));
return;
}
// 本例程仅支持文本消息,不支持二进制消息
if (!(frame instanceof TextWebSocketFrame)) {
throw new UnsupportedOperationException(String.format(
"%s frame types not supported", frame.getClass().getName()));
}
// 返回应答消息
String request = ((TextWebSocketFrame) frame).text();
//获取当前登录用户信息
Map<String, String> user = (Map<String, String>) JSON.parse(request);
String token = user.get("token");
String type = user.get("type");
JWSObject jwsObject = JWSObject.parse(token);
String payload = jwsObject.getPayload().toString();
JSONObject jsonObject = JSONUtil.parseObj(payload);
String userId = jsonObject.get("user_id").toString();
logger.debug("服务端收到:" + request);
TextWebSocketFrame tws = new TextWebSocketFrame("xxxxxx");
ChannelSupervise.addUserChannel(type.concat(userId), ctx.channel().id().asShortText());
// 返回【谁发的发给谁】
ctx.channel().writeAndFlush(tws);
}catch (Exception e) {
e.printStackTrace();
}
}
/**
* 唯一的一次http请求,用于创建websocket
*/
private void handleHttpRequest(ChannelHandlerContext ctx,
FullHttpRequest req) {
//要求Upgrade为websocket,过滤掉get/Post
if (!req.decoderResult().isSuccess()
|| (!"websocket".equals(req.headers().get("Upgrade")))) {
//若不是websocket方式,则创建BAD_REQUEST的req,返回给客户端
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST));
return;
}
WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
"ws://localhost:8083/websocket/*", null, false);
handshaker = wsFactory.newHandshaker(req);
if (handshaker == null) {
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
} else {
handshaker.handshake(ctx.channel(), req);
}
}
/**
* 拒绝不合法的请求,并返回错误信息
*/
private static void sendHttpResponse(ChannelHandlerContext ctx,
FullHttpRequest req, DefaultFullHttpResponse res) {
// 返回应答给客户端
if (res.status().code() != 200) {
ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(),
CharsetUtil.UTF_8);
res.content().writeBytes(buf);
buf.release();
}
ChannelFuture f = ctx.channel().writeAndFlush(res);
// 如果是非Keep-Alive,关闭连接
if (!isKeepAlive(req) || res.status().code() != 200) {
f.addListener(ChannelFutureListener.CLOSE);
}
}
}
保存客户端的信息
/**
* @author hizoo
*/
public class ChannelSupervise {
private static ChannelGroup GlobalGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
private static ConcurrentMap<String, ChannelId> ChannelMap = new ConcurrentHashMap();
private static ConcurrentMap<String, String> userChannelMap = new ConcurrentHashMap();
public static void addChannel(Channel channel) {
GlobalGroup.add(channel);
ChannelMap.put(channel.id().asShortText(), channel.id());
}
public static void addUserChannel(String userId, String channelId) throws UnknownHostException {
userChannelMap.put(userId, channelId);
}
public static void removeChannel(Channel channel) {
AtomicReference<String> removeKey = new AtomicReference<>();
GlobalGroup.remove(channel);
ChannelMap.remove(channel.id().asShortText());
userChannelMap.forEach((key, value) -> {
if (value.equals(channel.id().asShortText())) {
removeKey.set(key);
}
});
if (!StringUtil.isNullOrEmpty(removeKey.get())) {
userChannelMap.remove(removeKey.get());
}
}
public static Channel findChannelByUserId(String userid) {
if (!userChannelMap.containsKey(userid)) {
return null;
}
String channelId = userChannelMap.get(userid);
return GlobalGroup.find(ChannelMap.get(channelId));
}
}
实际调用:判断消息对象中的用户id列表是否为连接状态,如果存在则发送消息通知
/**
* 消息推送
*/
public void pushMessage(@RequestBody MessageDTO messageDTO) {
List<Channel> channels = new LinkedList<>();
String type = messageDTO.getUserType();
List<String> userIds = messageDTO.getUserIds();
Long id = messageService.insertMsg(userIds, messageDTO);
messageDTO.setId(id);
messageDTO.setReadFlag(false);
if (userIds.size() > 0) {
userIds.forEach((userId) -> {
Channel channel = ChannelSupervise.findChannelByUserId(type.concat(userId));
if (channel != null) channels.add(channel);
});
}
if (channels.size() > 0) {
channels.forEach((channel) -> {
TextWebSocketFrame tws = new TextWebSocketFrame(JSON.toJSONString(messageDTO));
channel.writeAndFlush(tws);
});
}
}
问题
在最初开发时消息推送功能一切正常。但是上了测试环境后,发现消息推送经常丢失,一开始怀疑是网络原因,排查后发现网络没有问题。又去系统日志里面详细排查,最后发现问题指向了服务器集群。由于消息推送的channel相关信息都是存储在ConcurrentMap上,开发环境都是在我本地运行,消息的发送接收都发生在我自己的机器上,所以消息的发送接收没有问题。但是测试环境下是3台服务器做的集群,消息在推送的过程中只能找到与本机连接的用户,连接在其他两台服务器上的用户,是接收不了消息的。
改进
第一种方案:RabbitMQ实现。
引入RabbitMQ,使用RabbitMQ的发布订阅模式来广播所有的Netty服务,当我们需要发送的用户不在本服务中,就将本次消息通过MQ发布出去,每个Netty服务去消费这个message。但是由于我的系统体量较小,引入RabbitMQ会对系统的复杂性进一步的提高,整个系统变重,故不考虑此方案。
第二种方案:将消息进行转发。
通过Redis去存储用户信息与所连接服务器IP地址。每次发送消息的时候,将发给不同用户的消息转发到他们所在服务器上再去进行推送。此方案代码改动量相对较少,也没额外引入其他中间件,较为适合我的系统,故选择。详细代码如下:
public static void addUserChannel(String userId, String channelId) throws UnknownHostException {
InetAddress address = InetAddress.getLocalHost();
//将用户id与所在服务器IP地址存入redis
StaticRedisUtil.staticRedisUtil.set(netty+userId,address.getHostAddress());
userChannelMap.put(userId, channelId);
}
发送消息时,分为两部分,先获取用户连接的信息,再转发消息。
@Override
@Transactional
public void pushMessage(MessageDTO messageDTO) {
List<String> userIds = messageDTO.getUserIds();
Long id = messageService.insertMsg(userIds, messageDTO);
messageDTO.setId(id);
messageDTO.setReadFlag(false);
//构造用户ipMap与ip集合
Map<String,String> userIpMap = new HashMap<>();
Set<String> ipSet = new HashSet<>();
if (userIds.size() > 0) {
userIds.forEach((userId) -> {
messageDTO.setUserId(Long.valueOf(userId));
String url = redisUtils.get(netty+messageDTO.getUserType()+userId);
if(StringUtils.isNotBlank(url)){
userIpMap.put(userId,url);
ipSet.add(url);
}
});
}
Set<String> userSet = userIpMap.keySet();
RestTemplate restTemplate = new RestTemplate();
for (String ip : ipSet) {
List<String> users = new ArrayList<>();
for(String userId:userSet){
if(ip.equals(userIpMap.get(userId))){
users.add(userId);
}
}
messageDTO.setUserIds(users);
//拼接消息转发地址
String realUrl = "http://" + ip + ":" + this.serverPort+"/message/pushMessageServed";
restTemplate.postForObject(realUrl, messageDTO, String.class);
}
}
@Override
public void pushMessageServed(MessageDTO messageDTO) {
String type = messageDTO.getUserType();
List<String> userIds = messageDTO.getUserIds();
for (String userId : userIds) {
Channel channel = ChannelSupervise.findChannelByUserId(type.concat(userId));
if (channel != null) {
TextWebSocketFrame tws = new TextWebSocketFrame(JSON.toJSONString(messageDTO));
channel.writeAndFlush(tws);
}
}
}
总结
在项目开发的实际过程中,很多的问题解决办法可能有很多种,最终选择方案时一定要根据环境,因地制宜的选择最合适自己系统的一种,不需要盲目的追求新技术,多功能,给自己的系统带来很多不稳定的隐患,任何上线的系统首要一定是保证项目的稳定运行和功能的实现。