公司早些时候接入了一款健康监测设备。近日由于业务原因将服务端口暴露到公网后,每当遭到恶意连接,系统就会创建大量线程。排查时发现,服务端直接使用了厂家提供的 demo 代码:该代码基于 Java 原生 Socket,每发现一个连接就创建一个独立线程处理后续通信,线程数量不受限制地增长,耗尽了系统资源,最终导致服务宕机,因此需要进行改造。
厂家提供的demo代码如下:
import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayList;
import java.util.List;
/**
 * Vendor-supplied blocking TCP server demo, reproduced as-is to illustrate the problem.
 *
 * <p>Accepts connections on a {@link ServerSocket} and starts one new, unpooled
 * {@link Thread} running {@code HealthReadThread} for every accepted socket, with no
 * upper bound on the number of threads.
 * NOTE(review): this one-thread-per-connection design is what lets a flood of malicious
 * connections create an unbounded number of threads.
 */
public class Demo {
/**
 * Starts the listener and accepts connections forever.
 *
 * @param args optional single argument: the port to listen on (defaults to 8003)
 */
public static void main(String[] args) {
int port = 8003;
if (args.length == 1) {
port = Integer.parseInt(args[0]);
}
ServerSocket ss;
try {
ss = new ServerSocket(port);
}
catch (Exception e) {
// Could not open the listening socket: report it and leave main.
System.out.println("服务端socket失败 port = " + port);
return;
}
System.out.println("启动socket监听 端口:" + port);
// Every accepted socket is added to this list. The only remove() call is in the
// null/closed branch below, so sockets are never dropped once added and the list
// keeps growing for as long as the process runs.
List<Socket> socketList = new ArrayList<>();
while (true) {
try {
Socket socket = ss.accept();
// NOTE(review): accept() blocks until a connection arrives and hands back a newly
// connected socket, so this null/closed branch looks unreachable — confirm.
if (socket == null || socket.isClosed()) {
socketList.remove(socket);
continue;
}
// NOTE(review): each accept() yields a new Socket object, so this duplicate check
// presumably never matches; it also scans the whole, ever-growing list.
if (socketList.contains(socket)) {
continue;
}
socketList.add(socket);
System.out.println("socket连接 address = " + socket.getInetAddress().toString() + " port = " + socket.getPort());
// One new thread per connection, never pooled or limited: the source of the
// unbounded thread growth described in the article.
new Thread(new HealthReadThread(socket)).start();
}
catch (IOException e) {
// Accept failures are only printed; the loop keeps accepting.
System.out.println(e.getMessage());
}
}
}
}
import java.io.*;
import java.net.Socket;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
/**
 * Per-connection reader from the vendor demo, reproduced as-is.
 *
 * <p>Each instance is run on its own thread (see {@code Demo}). It reads raw bytes from a
 * single device socket in a loop until that socket has been closed by this code.
 */
public class HealthReadThread implements Runnable {
private Socket socket;
HealthReadThread(Socket socket) {
this.socket = socket;
}
// NOTE(review): static, so it would be shared by every connection thread. It is not used in
// the code shown here (the business logic is omitted); confirm how the omitted code uses it.
private static String message = "";
/**
 * Reads chunks of up to 1024 bytes from the socket and converts each with
 * {@code StringUtils.byteToHexString} before the (omitted) business logic runs.
 * Exceptions are printed, not propagated.
 */
@Override
public void run() {
try {
// input
InputStream inPutStream = socket.getInputStream();
BufferedInputStream bis = new BufferedInputStream(inPutStream);
// BufferedReader br = new BufferedReader(new InputStreamReader(inPutStream));
// output (not used in the code shown; presumably used by the omitted business logic)
OutputStream outputStream = socket.getOutputStream();
BufferedOutputStream bw = new BufferedOutputStream(outputStream);
String ip = socket.getInetAddress().getHostAddress();
int port = socket.getPort();
// Last converted chunk; it is only overwritten on a successful read, never reset.
String readStr = "";
// char[] buf;
byte[] buf;
int readLen = 0;
// In the code shown, the loop's only exit is the isClosed() check, and the only place the
// socket gets closed is the IOException handler below.
while (true) {
if (socket.isClosed()) {
break;
}
// A fresh 1024-byte buffer is allocated on every iteration.
buf = new byte[1024];
try {
readLen = bis.read(buf);
// NOTE(review): read() returns -1 once the peer has closed the connection. That case lands
// here and `continue`s without closing the socket, so later reads keep returning -1 and the
// loop presumably spins forever: the thread never exits and keeps burning CPU.
if (readLen <= 0) {
// System.out.println(Thread.currentThread().getId() + "线程: " + "ip地址:" + ip + " 端口地址:" + port + "暂无接收数据");
continue;
}
System.out.println(Thread.currentThread().getId() + "线程: " + "ip地址:" + ip + " 端口地址:" + port + " 接收到原始命令长度:" + readLen);
// Vendor helper, not shown; judging by its name it presumably produces a hex string.
readStr = StringUtils.byteToHexString(buf, readLen);
// readStr = new String(buf ,0 , readLen);
} catch (IOException e) {
System.out.println(e.getMessage());
socket.close();
// continue;
// NOTE(review): readStr still holds the previous chunk here. If an earlier read
// succeeded, the business logic below runs once more on that stale data before the
// next isClosed() check ends the loop.
}
if (readStr == null || "".equals(readStr)) {
continue;
}
// business logic omitted
}
}
catch (Exception e) {
System.out.println(e.getMessage());
}
}
}
使用 Netty 进行改造:
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;
/**
 * Netty-based TCP server for the health monitoring devices.
 *
 * <p>Replaces the vendor demo's one-thread-per-connection model. A single boss thread accepts
 * connections and a fixed pool of worker event-loop threads handles all reads and writes, so
 * the number of threads no longer grows with the number of connections. A new
 * {@code DeviceMsgHandler} is created for every accepted channel.
 */
@Slf4j
@Component
public class DeviceNettyServer implements ApplicationRunner {

    /** Ports bound by {@link #start()}. */
    private static final int[] DEFAULT_PORTS = {8081, 8082};

    /** Number of worker event-loop threads shared by all accepted connections. */
    private static final int WORKER_THREADS = 4;

    /**
     * Invoked by Spring Boot as an {@link ApplicationRunner}; starts the server on the
     * default ports.
     */
    @Override
    public void run(ApplicationArguments args) throws Exception {
        start();
    }

    /**
     * Starts the server on the default ports (8081 and 8082).
     *
     * <p>Returns immediately; serving happens on a background thread named
     * {@code DeviceNettyServer}.
     */
    public void start() {
        start(DEFAULT_PORTS);
    }

    /**
     * Starts the server on the given ports, all sharing the same boss/worker threads.
     *
     * <p>Returns immediately; serving happens on a background thread named
     * {@code DeviceNettyServer}.
     *
     * @param ports one or more TCP ports to listen on
     * @throws IllegalArgumentException if no port is given
     */
    public void start(int... ports) {
        if (ports == null || ports.length == 0) {
            throw new IllegalArgumentException("at least one port is required");
        }
        // Copy so later changes to the caller's array cannot affect the server thread.
        int[] bindPorts = ports.clone();
        Thread thread = new Thread(() -> serve(bindPorts));
        thread.setName("DeviceNettyServer");
        thread.start();
    }

    /**
     * Binds every port, then blocks until all listening channels are closed.
     *
     * <p>The event-loop groups are always shut down on the way out, including when a bind fails.
     */
    private void serve(int[] ports) {
        // Server-side NIO thread groups: 1 thread accepts, a fixed pool does the I/O.
        EventLoopGroup bossGroup = new NioEventLoopGroup(1);
        EventLoopGroup workerGroup = new NioEventLoopGroup(WORKER_THREADS);
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    // NIO transport: connections are multiplexed over the worker threads.
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel ch) throws Exception {
                            // One handler instance per accepted channel.
                            ch.pipeline().addLast(new DeviceMsgHandler());
                        }
                    });
            ChannelFuture[] bound = new ChannelFuture[ports.length];
            for (int i = 0; i < ports.length; i++) {
                // sync() waits for the bind and rethrows its failure (e.g. port already in use).
                bound[i] = b.bind(ports[i]).sync();
            }
            log.info("启动监听, 端口:{}", joinPorts(ports));
            // Keep this thread alive until every listening channel has been closed.
            for (ChannelFuture f : bound) {
                f.channel().closeFuture().sync();
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so whoever interrupted this thread can still observe it.
            Thread.currentThread().interrupt();
            log.error("启动监听失败", e);
        } catch (Exception e) {
            log.error("启动监听失败", e);
        } finally {
            workerGroup.shutdownGracefully();
            bossGroup.shutdownGracefully();
        }
    }

    /** Joins ports with "、", matching the log format {@code 8081、8082}. */
    private static String joinPorts(int[] ports) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < ports.length; i++) {
            if (i > 0) {
                sb.append('、');
            }
            sb.append(ports[i]);
        }
        return sb.toString();
    }
}
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class DeviceMsgHandler extends SimpleChannelInboundHandler<ByteBuf> {
/**
 * Registry of currently connected devices, keyed by their Netty channel.
 *
 * <p>Static and concurrent because each accepted connection gets its own handler instance,
 * while all of them share this single map.
 */
private static final ConcurrentHashMap<Channel, DeviceDTO> CONNECTION_DEVICE_MAP =
        new ConcurrentHashMap<Channel, DeviceDTO>(8);
/**
 * Called once this handler has been added to a newly accepted channel's pipeline, i.e. right
 * after a device connects. Logs the peer address and sends the query-device-info command.
 *
 * @param ctx context of the new connection
 */
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
    Channel channel = ctx.channel();
    // Let SLF4J format the address: unlike remoteAddress().toString(), this cannot throw an
    // NPE when the address is unavailable, matching the other connection callbacks.
    log.info("发现连接, remoteAddress: {}", channel.remoteAddress());
    // Ask the device to report its information.
    sendQuery(channel);
}
/**
 * Reads one inbound chunk from the device and forwards it down the pipeline.
 *
 * <p>{@link SimpleChannelInboundHandler} releases {@code msg} automatically when this method
 * returns. The buffer is therefore retained before being passed on. Without that, the next
 * handler (or the pipeline tail) would use or release an already-released buffer, and Netty
 * would throw {@code IllegalReferenceCountException} on every read.
 *
 * @param ctx context of the connection
 * @param msg the received bytes
 */
@Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) {
    // Copy the readable bytes without moving readerIndex, so the next handler still sees them.
    byte[] bytes = new byte[msg.readableBytes()];
    msg.getBytes(msg.readerIndex(), bytes);
    // business handling omitted
    // Pass on to the next handler; retain() balances the superclass's automatic release.
    ctx.fireChannelRead(msg.retain());
}
/**
 * Used as the disconnect hook: logs the peer address and forgets the channel's device.
 *
 * @param ctx context of the connection being torn down
 */
@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
    Channel closed = ctx.channel();
    log.info("连接断开, remoteAddress: {}", closed.remoteAddress());
    CONNECTION_DEVICE_MAP.remove(closed);
}
/**
 * Handles an exception raised while processing this connection.
 *
 * <p>Logs the exception together with its cause, forgets the device and closes the channel.
 * Closing ensures a faulty or malicious peer does not keep an open socket after its device
 * entry has been dropped.
 *
 * @param ctx   context of the failing connection
 * @param cause the exception that was raised
 */
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
    Channel channel = ctx.channel();
    log.warn("连接异常, remoteAddress: {}", channel.remoteAddress(), cause);
    CONNECTION_DEVICE_MAP.remove(channel);
    ctx.close();
}
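改造后,所有连接的读写都由固定的 4 个 worker 线程(外加 1 个负责接收连接的 boss 线程)处理,消除了原先恶意连接导致线程数无限增长的问题;基于 NIO 的模型也大大提高了系统资源利用率。需要注意的是,NIO 只解决了线程数量的问题,恶意连接本身仍会占用文件描述符和内存,生产环境中还应配合空闲超时(如 Netty 的 IdleStateHandler)、连接数限制等手段,及时关闭无效连接。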