如何基于java代理对支持udf功能的常用数据库返回结果进行敏感数据的脱敏

如何用一台服务器的代理端口转发数据库的所有流量并对相关操作进行记录,数据脱敏,修改或阻断

背景

一年结束了,又是新的一年开始了,针对前段时间研究的内容做一个简单的总结,上次写文章还是在去年了,写的内容是基于netty做的一个tcp端口动态代理的的工程,当时做这个工具的目的是为了解决两台服务器之间不能直接通信所以在一台两个服务都能访问的代理服务器上去转发流量的问题,但是这个工程就只做了流量的的转发,并没有对流量进行解析,那么这篇文章就基于上个工程做一个简单的扩展,对转发的流量进行解析,对危险的请求进行拦截,有风险的进行告警,对于我们不希望转发的流量可以进行拦截或修改成我们希望转发的流量。

原理

这里选择的解析流量的内容是基于数据库sql语句的解析,然后对解析的sql进行修改,新的sql语句组装成新的报文发送给数据库服务器,简单来说就是客户端发送的sql语句是(select name,phone from user ) 经过我们代理服务器进行加工后可以让服务器收到新的sql语句为(select name,phone from user limit 10)(这个操作就可以有效防止有人恶意攻击数据库一直查询大量数据导致数据库繁忙),对客户端和服务端都是无感知的从而达到防止攻击的目的,当然这里替换的功能很多,可以做很多事情,其目的并不是只是加一个限制行数而已。

说明

这只是一个验证猜想的实验产品,重心都放在解析流量上了,所以代码就是写流水先实现功能便于排查问题的,优化空间还很大后续再处理。

mysql实现

接下来将对数据库的报文进行解析:

默认处理方式


/**
 * @description: 默认的处理方式,不做任何处理
 * @author: yx
 * @date: 2021/12/8 10:20
 *
 * <p>
 */
@Slf4j
public class DefaultSqlParser {

    //默认处理方式,对任何数据都不做处理,直接转发
    public void dealChannel(ChannelHandlerContext ctx, ProxyConfig config, Channel channel, Object msg) {
        channel.writeAndFlush(msg);
    }

    /**
     * 可以对删除语句自行做控制,这里只做日志记录
     *
     * @param ctx
     * @param config
     * @param channel
     * @param sql
     */
    void delete(ChannelHandlerContext ctx, ProxyConfig config, Channel channel, String sql) {
        InetSocketAddress inetSocketAddress = (InetSocketAddress) ctx.channel().remoteAddress();
        log.info("{}主机在{}上执行了删除语句:{}", inetSocketAddress.getAddress(), config.getRemoteAddr(), sql);
    }

    /**
     * 可以对新增语句自行做权限控制或拦截,这里只做日志记录
     *
     * @param ctx
     * @param config
     * @param channel
     * @param sql
     */
    void insert(ChannelHandlerContext ctx, ProxyConfig config, Channel channel, String sql) {
        InetSocketAddress inetSocketAddress = (InetSocketAddress) ctx.channel().remoteAddress();
        log.info("{}主机在{}上执行了新增语句:{}", inetSocketAddress.getAddress(), config.getRemoteAddr(), sql);
    }

    /**
     * 可以对修改语句自行做控制,检验或拦截,这里只做日志记录
     *
     * @param ctx
     * @param config
     * @param channel
     * @param sql
     */
    void update(ChannelHandlerContext ctx, ProxyConfig config, Channel channel, String sql) {
        InetSocketAddress inetSocketAddress = (InetSocketAddress) ctx.channel().remoteAddress();
        log.info("{}主机在{}上执行了修改语句:{}", inetSocketAddress.getAddress(), config.getRemoteAddr(), sql);
    }
}

mysql数据包结构组成

在这里插入图片描述

代码解析


/**
 * @description: some desc
 * @author: yx
 * https://github.com/MyCATApache/Mycat-Server.git
 * mysql的数据包格式:数据长度标识,前3位+分割为第四位固定为0,第五位表示操作动作(增删查改...)+最后一位据0结束符,
 * 据网上资料说明数据包最大长度为16M最后一位是存储下一个字节长度的,暂不验证,场景暂时用不到,不处理暂时不收影响
 * 前三位算法:1=字符长度与0xff(255)进行与运算,
 * 2=字符串长度右移8位
 * 3=字符串长度右移16位
 * Buffer.writeLongInt
 * byte[] b = this.byteBuffer;
 * b[this.position++] = (byte) (i & 0xff);
 * b[this.position++] = (byte) (i >>> 8);
 * b[this.position++] = (byte) (i >>> 16);
 * 另外还涉及到连接方的连接编码方式,会在创建连接的时候有个:SET NAMES utf8mb4类似的命令,截取出来就可以知道对方的编码是什么格式,默认是gbk
 * @date: 2021/12/7 15:40
 */
@Slf4j
public class MySqlParser extends DefaultSqlParser {
    String rule = "concat(SUBSTR(#field#,1,CHAR_LENGTH(#field#)/2),substr('*************',CHAR_LENGTH(#field#)/2,CHAR_LENGTH(#field#)/2)) as #field#";
    public static Charset defaultCharset = Charset.forName("gbk");
    Map<String, ByteBuf> bufferMap = new HashMap();

    public void dealChannel(ChannelHandlerContext ctx, ProxyConfig config, Channel channel, Object msg) {

        ByteBuf readBuffer = (ByteBuf) msg;
        //如果是服务端发送的消息远程地址为空
        InetSocketAddress remoteAddress = (InetSocketAddress) channel.remoteAddress();
        String hostString = remoteAddress.getHostString();
        int port = remoteAddress.getPort();
        //只有发送给数据库的数据才需要进行处理
        int readableBytes = readBuffer.readableBytes();
        if (hostString.equals(config.getRemoteAddr()) && Objects.equals(port, config.getRemotePort())) {
            //第一步先获取会话的id,如果当前会话的pid没有被结束则直接把所有的数据写入到缓冲区buffer里面
            String localPid = channel.localAddress().toString();
            if (bufferMap.containsKey(localPid)) {
                ByteBuf byteBuf = bufferMap.get(localPid);
                //如果写入完全了则直接进行sql解析
                int index = readBuffer.writerIndex();
                byte[] tmpBytes = new byte[index];
                readBuffer.getBytes(0, tmpBytes);
                byteBuf.writeBytes(tmpBytes);
                if (byteBuf.writerIndex() == byteBuf.capacity()) {
                    dealBytes(ctx, config, channel, byteBuf);
                }
            } else {
                byte[] preData = new byte[5];  //处理客户端发送的消息
                readBuffer.getBytes(0, preData);
                //提前获取所有字节内容
                int allDataLength = getDataLength(preData);
                //如果当前缓冲区的数据与标致位的长度一致则直接处理数据
                if (allDataLength + 4 == readableBytes) {
                    byte preDatum = preData[4];
                    switch (preDatum) {
                        case MySQLPacket.COM_QUERY:
                            dealBytes(ctx, config, channel, readBuffer);
                            break;
                        default:
                            readBuffer.retain();
                            channel.writeAndFlush(readBuffer);
                            break;
                    }

                } else {
                    //说明数据包不完全,先继续接收数据包等接收完全后再处理sql
                    ByteBuf tmpBuffer = Unpooled.buffer(allDataLength + 4);
                    tmpBuffer.writeBytes(readBuffer);
                    bufferMap.put(localPid, tmpBuffer);
                }
            }
        } else {
            readBuffer.retain();
            channel.writeAndFlush(readBuffer);
        }
    }

    /**
     * 处理完整数据包的字符内容
     * @param ctx
     * @param config
     * @param channel
     * @param byteBuf
     */
    private void dealBytes(ChannelHandlerContext ctx, ProxyConfig config, Channel channel, ByteBuf byteBuf) {
        int readableBytes = byteBuf.readableBytes();

        byte[] datas = new byte[readableBytes - 5];
        byteBuf.getBytes(5, datas);
        Charset charset = Charset.defaultCharset();
        String localPid = channel.localAddress().toString();

        String sql = new String(datas, Optional.ofNullable(ThreadPublic.getCharset(localPid)).orElse(defaultCharset));
        //替换掉客户端自己生成的注释语句
        sql = sql.replaceAll("(?ms)('(?:''|[^'])*')|--.*?$|/\\*.*?\\*/", "$1").trim();
        bufferMap.remove(localPid);
        if (sql.toUpperCase(Locale.ROOT).startsWith("SET NAMES")) {
            String charsetName = sql.replace("SET NAMES", "").trim();
            switch (charsetName) {
                case "utf8mb4":
                    charset = Charset.forName("utf8");
                    break;
                default:
                    break;
            }
            ThreadPublic.putCharset(localPid, charset);

        } else if (sql.toUpperCase(Locale.ROOT).startsWith("SELECT") || sql.toUpperCase(Locale.ROOT).contains("SELECT")) {
            sql = replaceSql(sql);
            byte[] newSqlBytes = sql.getBytes();
            int sqlLength = newSqlBytes.length + 1;
            byteBuf.writerIndex(0);
            byteBuf.writeByte((byte) (sqlLength & 0xff));
            byteBuf.writeByte((byte) (sqlLength >>> 8));
            byteBuf.writeByte((byte) (sqlLength >>> 16));
            byteBuf.writeByte((byte) 0);
            byteBuf.writeByte(MySQLPacket.COM_QUERY);
            byteBuf.writeBytes(newSqlBytes);
        } else if (sql.toUpperCase(Locale.ROOT).startsWith("DELETE")) {
            delete(ctx, config, channel, sql);
        } else if (sql.toUpperCase(Locale.ROOT).startsWith("UPDATE")) {
            update(ctx, config, channel, sql);
        } else if (sql.toUpperCase(Locale.ROOT).startsWith("INSERT")) {
            insert(ctx, config, channel, sql);
        }
        byteBuf.readerIndex(0);
        channel.writeAndFlush(byteBuf);


    }

    /**
     * 获取数据包长度
     * @param datas
     * @return
     */
    int getDataLength(byte[] datas) {
        return (datas[0] & 0xff) + ((datas[1] & 0xff) << 8) + ((datas[2] & 0xff) << 16);
    }

    /**
     * 替换查询语句的sql
     * @param sql
     * @return
     */
    public String replaceSql(String sql) {
        if (sql.toLowerCase(Locale.ROOT).startsWith("select") && sql.toLowerCase(Locale.ROOT).contains("from") && (!sql.toLowerCase(Locale.ROOT).contains("information_schema")) && (!sql.contains("*"))) {
            int select = sql.toLowerCase().indexOf("select");
            int form = sql.indexOf("from");
            String substring = sql.substring(select, form);
            String[] split = substring.split(",");
            List<String> list = new ArrayList<>();
            for (String s : split) {
                String select1 = s.replace("select", "");
                list.add(rule.replace("#field#", select1));
            }
            String join = StringUtils.join(list, ",");
            sql = "select" + " " + join + " " + sql.substring(form);
            log.debug("替换后的的sql:{}", sql);
            return sql;
        }
        return sql;
    }
}

测试方案

连接方法

代理访问地址

原始效果

在这里插入图片描述

通过代码服务器及端口的访问效果**(注意不能用select * ,的解析这部分代码需要配置,在这里做不合适,且只支持字符串类型的字段)*

在这里插入图片描述

通过代理服务器端口进行增删除改效果

在这里插入图片描述
这里统一只做了日志记录,可以做其他操作,阻断,修改或告警等动作可以自行扩展实现。

postgrepsql实现

postgrepsql数据包结构组成

来源于dbeaver,jdbc,idea工具连接的数据包结构

在这里插入图片描述

来源于navicate,psql工具连接的数据包结构

在这里插入图片描述

代码解析


/**
 * @description: some desc
 * @author: yx
 * @date: 2021/12/7 15:40
 * 关键代码在这里,把sql写入buffer并设置头
 * QueryExecutorImpl.sendParse
 * postgrepsql的数据包自带的简化格式版最大长度为64字节,而且服务段接收也是最大长度为64字节
 * 暂时没有处理中文格式问题,后续测试再处理
 * 现阶注意只支持具体sql语句,不支持*号的脱敏
 */
@Slf4j
public class PostGrepSqlParser extends DefaultSqlParser {
    String charset = "utf8";
    String rule = "concat(SUBSTR(#field#,1,CHAR_LENGTH(#field#)/2),substr('*************',CHAR_LENGTH(#field#)/2,CHAR_LENGTH(#field#)/2)) as #field#";
    Map<String, ByteBuf> bufferMap = new HashMap();
    Set<String> isWait = new HashSet<>();

    @Override
    public void dealChannel(ChannelHandlerContext ctx, ProxyConfig config, Channel channel, Object msg) {
        ByteBuf readBuffer = (ByteBuf) msg;
        int oldByteLength = readBuffer.readableBytes();
        InetSocketAddress remoteAddress = (InetSocketAddress) channel.remoteAddress();
        String hostString = remoteAddress.getHostString();
        int remotePort = remoteAddress.getPort();
        if (Objects.equals(hostString, config.getRemoteAddr()) && Objects.equals(config.getRemotePort(), remotePort) && oldByteLength > 8) {
            dealType(ctx, config, channel, msg);
        } else {
            channel.writeAndFlush(readBuffer);
        }
    }

    public void dealType(ChannelHandlerContext ctx, ProxyConfig config, Channel channel, Object msg) {
        ByteBuf readBuffer = (ByteBuf) msg;
        //首先第一步看看第一位是不是80或81
        byte startByte = readBuffer.getByte(0);
        String localPid = channel.localAddress().toString();
        if (startByte == 80 || startByte == 81 || isWait.contains(localPid)) {
            //如果是服务端发送的消息远程地址为空
            //只有发送给数据库的数据才需要进行处理
            int readableBytes = readBuffer.readableBytes();
            //第一步先获取会话的id,如果当前会话的pid没有被结束则直接把所有的数据写入到缓冲区buffer里面
            if (bufferMap.containsKey(localPid)) {
                ByteBuf byteBuf = bufferMap.get(localPid);
                //如果写入完全了则直接进行sql解析
                int index = readBuffer.writerIndex();
                byte[] headerBytes = new byte[4];
                byteBuf.getBytes(1, headerBytes);
                //获取长度
                byte[] tmpBytes = new byte[index];
                readBuffer.getBytes(0, tmpBytes);
                byteBuf.writeBytes(tmpBytes);
                int readableBytesNew = byteBuf.readableBytes();
                if (readableBytesNew >= byteBuf.writerIndex()) {
                    dealBytes(ctx, config, channel, byteBuf);
                    isWait.remove(localPid);
                    bufferMap.remove(localPid);

                }
            } else {
                //取第一位,如果是80表示从jdbc和idea来的请求,数据复杂一点,如果是81表示从navicat和psql的客户端来的请求,结构稍微简单点
                byte aByte = readBuffer.getByte(0);
                byte[] headerBytes = new byte[4];
                readBuffer.getBytes(1, headerBytes);
                //获取长度
                int byteLength = getByteLength(headerBytes);
                if (readableBytes >= byteLength) {
                    dealBytes(ctx, config, channel, readBuffer);
                } else {
                    //说明数据包不完全,先继续接收数据包等接收完全后再处理sql
                    ByteBuf tmpBuffer = Unpooled.buffer(byteLength);
                    tmpBuffer.writeBytes(readBuffer);
                    bufferMap.put(localPid, tmpBuffer);
                    isWait.add(localPid);
                }
            }
        } else {
            readBuffer.retain();
            channel.writeAndFlush(readBuffer);
        }
    }

    public void dealBytes(ChannelHandlerContext ctx, ProxyConfig config, Channel channel, ByteBuf readBuffer) {
        int startByte = readBuffer.getByte(0);
        switch (startByte) {
            case 80:
                dealComplex(ctx, config, channel, readBuffer);
                break;
            case 81:
                dealSimple(ctx, config, channel, readBuffer);
                break;
            default:
                readBuffer.retain();
                channel.writeAndFlush(readBuffer);
                break;
        }
    }


    /**
     * 根据byte数组得到字符串长度
     *
     * @param data
     * @return
     */

    public static int getByteLength(byte[] data) {
        int result = 0;
        for (int i = 0; i < data.length; i++) {
            result += (data[i] & 0xff) << ((3 - i) * 8);
        }
        return result;
    }

    /**
     * 根据字符串长度去生成数组中的信息
     *
     * @param length
     * @param data
     */
    public static void setHeaderBytes(int length, byte[] data) {
        data[0] = (byte) (length >>> 24);
        data[1] = (byte) (length >>> 16);
        data[2] = (byte) (length >>> 8);
        data[3] = (byte) length;
    }

    /**
     * 处理简单的客户端,指navicat和psql客户端发送的请求
     * 先处理假设sql最长只有64个字节,长的后续再处理
     *
     * @param readBuffer
     */
    void dealSimple(ChannelHandlerContext ctx, ProxyConfig config, Channel channel, ByteBuf readBuffer) {
        int oldByteLength = readBuffer.readableBytes();
        byte headByte = readBuffer.getByte(0);
        byte[] headerBytes = new byte[4];
        readBuffer.getBytes(1, headerBytes);
        //获取长度
        int byteLength = getByteLength(headerBytes);
        //读取数据
        byte[] oldSqlBytes = new byte[byteLength - 5];
        readBuffer.getBytes(5, oldSqlBytes);
        String oldSql = new String(oldSqlBytes);
        readBuffer.retain();
        if (oldSql.toLowerCase().startsWith("select") && (!oldSql.toLowerCase().contains("information_schema")) && (!oldSql.contains("*"))) {
            String newSql = replaceSql(oldSql);
            byte[] newSqlBytes = newSql.getBytes();
            setHeaderBytes(newSqlBytes.length + 5, headerBytes);
            readBuffer.writerIndex(0);
            readBuffer.writeByte(headByte);
            readBuffer.writeBytes(headerBytes);
            //这种数据包格式的服务端一次只能接收64个字节的包,比较恶心需要分多次发送
            //这里有很大优化空间,重心现在放在解析数据包上暂不处理,后续再优化
            for (int i = 0; i < newSqlBytes.length; i++) {
                readBuffer.writeByte(newSqlBytes[i]);
                int index = readBuffer.writerIndex();
                if (index == 64) {
                    channel.writeAndFlush(readBuffer);
                    readBuffer = Unpooled.buffer(64);
                }

            }
            //注意这里的结束位不能省略
            readBuffer.writeByte(0);
            channel.writeAndFlush(readBuffer);

        } else if (oldSql.toUpperCase(Locale.ROOT).startsWith("DELETE")) {
            delete(ctx, config, channel, oldSql);
            readBuffer.retain();
            channel.writeAndFlush(readBuffer);
        } else if (oldSql.toUpperCase(Locale.ROOT).startsWith("UPDATE")) {
            update(ctx, config, channel, oldSql);
            readBuffer.retain();
            channel.writeAndFlush(readBuffer);
        } else if (oldSql.toUpperCase(Locale.ROOT).startsWith("INSERT")) {
            insert(ctx, config, channel, oldSql);
            readBuffer.retain();
            channel.writeAndFlush(readBuffer);
        } else {
            readBuffer.retain();
            channel.writeAndFlush(readBuffer);
        }

    }

    /**
     * 解析sql进行替换
     *
     * @param sql
     * @return
     */
    public String replaceSql(String sql) {
        //这里有可能会出现select version等情况的sql,后续再处理,也可能就不会走到这里来先忽略
        try {
            int select = sql.toLowerCase().indexOf("select");
            int form = sql.indexOf("from");
            String substring = sql.substring(select, form);
            String[] split = substring.split(",");
            List<String> list = new ArrayList<>();
            for (String s : split) {
                String select1 = s.replace("select", "");
                list.add(rule.replace("#field#", select1));
            }
            String join = StringUtils.join(list, ",");
            sql = "select" + " " + join + " " + sql.substring(form);
            log.info("执行完sql替换即将执行的sql为:{}", sql);
        } catch (Exception e) {
            log.debug("替换sql失败,原sql为:{}", sql);
        }
        return sql;

    }

    /**
     * 处理复杂关系的数据,指jdbc和idea连接的请求
     * 这种客户端有允许发送大于64字节的数据了,没心情研究为什么
     *
     * @param readBuffer
     */
    void dealComplex(ChannelHandlerContext ctx, ProxyConfig config, Channel channel, ByteBuf readBuffer) {
        int oldByteLength = readBuffer.readableBytes();

        byte headByte = readBuffer.getByte(0);
        byte[] headerBytes = new byte[4];
        readBuffer.getBytes(1, headerBytes);
        //获取长度
        int byteLength = getByteLength(headerBytes);
        //读取数据
        byte[] oldSqlBytes = new byte[byteLength - 8];
        readBuffer.getBytes(6, oldSqlBytes);
        String oldSql = new String(oldSqlBytes);
        byte[] endBytes = new byte[oldByteLength - byteLength + 8 - 6];
        readBuffer.getBytes(byteLength - 8 + 6, endBytes);
        readBuffer.retain();
        if (oldSql.toLowerCase().contains("select") && (!oldSql.toLowerCase().contains("information_schema"))) {
            String newSql = replaceSql(oldSql);
            byte[] newSqlBytes = newSql.getBytes();
            setHeaderBytes(newSqlBytes.length + 8, headerBytes);
            readBuffer.writerIndex(0);
            readBuffer.writeByte(headByte);
            readBuffer.writeBytes(headerBytes);
            readBuffer.writeByte(0);
            readBuffer.writeBytes(newSqlBytes);
            readBuffer.writeBytes(endBytes);
        } else if (oldSql.toUpperCase(Locale.ROOT).startsWith("DELETE")) {
            delete(ctx, config, channel, oldSql);
        } else if (oldSql.toUpperCase(Locale.ROOT).startsWith("UPDATE")) {
            update(ctx, config, channel, oldSql);
        } else if (oldSql.toUpperCase(Locale.ROOT).startsWith("INSERT")) {
            insert(ctx, config, channel, oldSql);
        }
        readBuffer.readerIndex(0);
        channel.writeAndFlush(readBuffer);
    }
}




测试方案

连接方法

在这里插入图片描述

原始效果

在这里插入图片描述

通过代码服务器及端口的访问效果**(注意不能用select * ,的解析这部分代码需要配置,在这里做不合适,且只支持字符串类型的字段)*

在这里插入图片描述

通过代理服务器端口进行增删除改效果

在这里插入图片描述

mariadb实现

mariadb本就是mysql的一个分支,初步用mysql的报文解析器也可以实现基础功能,暂时用同一个解析器,后续有问题再扩展。

说明

由于oracle、sqlserver和gbase,sysbase,db2,达梦等其他数据库是非开源数据库,开源易被举报,这里不便展示他们的报文格式,只展示两种开源数据库的报文格式,MariaDb的报文与mysql报文格式很相似,有兴趣的同学可以自行研究,也可以私信大家一起探讨。

温馨提示:

sql不支持*的写法,只能写具体字段,因为这只是一个研究原理用的实验产品,很多细节还没处理,bug也还有很多,只给大家提供一个研究思路。
具体代码见github地址

本文纯属个人学习产物,因为网上一直没有相关资料所以分享出来给感兴趣的朋友一起研究,如有侵权请私信联系作者。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值