MyBatis 执行 SQL 文件

1. 批量执行的MyScriptRunner类

import org.apache.ibatis.jdbc.RuntimeSqlException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.PrintWriter;
import java.io.Reader;
import java.io.StringWriter;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.Statement;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * SQL script runner for MySQL, adapted from MyBatis {@code org.apache.ibatis.jdbc.ScriptRunner}.
 *
 * <p>Two execution modes are supported:
 * <ul>
 *   <li><b>Full-script (batch) mode</b> ({@code sendFullScript = true}): the script is sent to the
 *       database in chunks of at least {@link #setBatchLineCount(int) batchLineCount} lines. A chunk is
 *       only cut after a complete statement (a non-comment line ending with the delimiter), so a
 *       multi-line statement is never split across two chunks. Sending several statements in one
 *       call requires the MySQL driver option {@code allowMultiQueries=true}.</li>
 *   <li><b>Line-by-line mode</b>: lines are accumulated until one contains the delimiter, and the
 *       accumulated statement is then executed. If a message buffer was set via
 *       {@link #setMsg(StringWriter)} and it is non-empty after the whole script was read, the
 *       transaction is not committed. This class only reads that buffer; the caller is expected to
 *       point the error log writer at it.</li>
 * </ul>
 *
 * <p>{@link #runScript(Reader)} always calls {@code rollback()} at the end when auto-commit is off,
 * which discards anything that was not explicitly committed.
 *
 * <p>Instances keep mutable configuration (including the current delimiter) without
 * synchronization; use one instance per thread.
 * @author 
 * @date 2023年10月10日
 */
public class MyScriptRunner {

    private static final Logger logger = LoggerFactory.getLogger(MyScriptRunner.class);

    private static final String LINE_SEPARATOR = System.lineSeparator();

    private static final String DEFAULT_DELIMITER = ";";

    /** Default minimum number of lines per chunk in full-script mode. */
    private static final int DEFAULT_BATCH_LINE_COUNT = 1000;

    /**
     * Matches a comment line that changes the delimiter, e.g. {@code -- @DELIMITER $$};
     * group 5 captures the new delimiter.
     */
    private static final Pattern DELIMITER_PATTERN = Pattern.compile("^\\s*((--)|(//))?\\s*(//)?\\s*@DELIMITER\\s+([^\\s]+)", Pattern.CASE_INSENSITIVE);

    private final Connection connection;

    private boolean stopOnError;
    private boolean throwWarning;
    private boolean autoCommit;
    private boolean sendFullScript;
    private boolean removeCRs;
    private boolean escapeProcessing = true;
    private int batchLineCount = DEFAULT_BATCH_LINE_COUNT;
    /** Error buffer inspected by line-by-line mode to decide whether to commit; may be null. */
    private StringWriter msg;

    private PrintWriter logWriter = new PrintWriter(System.out);
    private PrintWriter errorLogWriter = new PrintWriter(System.err);

    private String delimiter = DEFAULT_DELIMITER;
    private boolean fullLineDelimiter;

    public MyScriptRunner(Connection connection) {
        this.connection = connection;
    }

    public void setStopOnError(boolean stopOnError) {
        this.stopOnError = stopOnError;
    }

    public void setThrowWarning(boolean throwWarning) {
        this.throwWarning = throwWarning;
    }

    public void setAutoCommit(boolean autoCommit) {
        this.autoCommit = autoCommit;
    }

    public StringWriter getMsg() {
        return msg;
    }

    /**
     * Sets the buffer that line-by-line mode inspects after reading a script: when it is non-empty
     * the transaction is not committed.
     * @param msg error buffer, or null to always commit in line-by-line mode
     */
    public void setMsg(StringWriter msg) {
        this.msg = msg;
    }

    public void setSendFullScript(boolean sendFullScript) {
        this.sendFullScript = sendFullScript;
    }

    public void setRemoveCRs(boolean removeCRs) {
        this.removeCRs = removeCRs;
    }

    /**
     * Sets the minimum number of lines collected before a chunk is sent in full-script mode.
     * Chunks are still only cut at statement boundaries. Keep chunks below MySQL's
     * {@code max_allowed_packet}.
     * @param batchLineCount minimum lines per chunk; must be positive
     * @throws IllegalArgumentException if {@code batchLineCount <= 0}
     */
    public void setBatchLineCount(int batchLineCount) {
        if (batchLineCount <= 0) {
            throw new IllegalArgumentException("batchLineCount must be positive: " + batchLineCount);
        }
        this.batchLineCount = batchLineCount;
    }

    /**
     * Sets the escape processing.
     * @param escapeProcessing
     *          the new escape processing
     * @since 3.1.1
     */
    public void setEscapeProcessing(boolean escapeProcessing) {
        this.escapeProcessing = escapeProcessing;
    }

    public void setLogWriter(PrintWriter logWriter) {
        this.logWriter = logWriter;
    }

    public void setErrorLogWriter(PrintWriter errorLogWriter) {
        this.errorLogWriter = errorLogWriter;
    }

    public void setDelimiter(String delimiter) {
        this.delimiter = delimiter;
    }

    public void setFullLineDelimiter(boolean fullLineDelimiter) {
        this.fullLineDelimiter = fullLineDelimiter;
    }

    /**
     * Runs the script read from {@code reader} in the configured mode, then rolls back whatever
     * was not committed (only when auto-commit is off).
     * @param reader script source; not closed by this method
     * @throws RuntimeSqlException if auto-commit cannot be set or execution fails
     */
    public void runScript(Reader reader) {
        setAutoCommit();

        try {
            if (sendFullScript) {
                executeFullScript(reader);
            } else {
                executeLineByLine(reader);
            }
        } finally {
            rollbackConnection();
        }
    }

    /**
     * Sends the script in chunks of at least {@code batchLineCount} lines, each cut at a statement
     * boundary, and commits after the last chunk.
     * @param reader script source
     */
    private void executeFullScript(Reader reader) {
        StringBuilder script = new StringBuilder();
        try {
            BufferedReader lineReader = new BufferedReader(reader);
            String line;
            int linesInChunk = 0;
            while ((line = lineReader.readLine()) != null) {
                script.append(line);
                script.append(LINE_SEPARATOR);
                linesInChunk++;
                // Keep each chunk below MySQL's max_allowed_packet, but only cut it after a complete
                // statement: cutting on a raw line count could split a multi-line statement in two.
                if (linesInChunk >= batchLineCount && endsStatement(line.trim())) {
                    sendChunk(script);
                    linesInChunk = 0;
                }
            }
            // Send whatever remains after the last full chunk.
            if (script.length() != 0) {
                sendChunk(script);
            }
            logger.info("批处理务提交中,请耐心等待...");
            commitConnection();
        } catch (Exception e) {
            logger.error("批处理事务回滚中请耐心等待...");
            String message = "Error executing: " + script + ".  Cause: " + e;
            printlnError(message);
            throw new RuntimeSqlException(message, e);
        }
    }

    /** Executes the buffered chunk and clears the buffer only if execution succeeded. */
    private void sendChunk(StringBuilder script) throws SQLException {
        String command = script.toString();
        println(command);
        executeStatement(command);
        script.setLength(0);
    }

    /** Returns true if the trimmed line is a non-comment line that ends with the current delimiter. */
    private boolean endsStatement(String trimmedLine) {
        return !lineIsComment(trimmedLine) && trimmedLine.endsWith(delimiter);
    }

    /**
     * Executes the script statement by statement and commits only if no error was recorded in
     * {@link #msg}.
     * @param reader script source
     */
    private void executeLineByLine(Reader reader) {
        StringBuilder command = new StringBuilder();
        try {
            BufferedReader lineReader = new BufferedReader(reader);
            String line;
            while ((line = lineReader.readLine()) != null) {
                handleLine(command, line);
            }
            if (!hasRecordedErrors()) {
                logger.info("逐行事务提交中,请耐心等待...");
                commitConnection();
            } else {
                logger.info("逐行事务回滚中,请耐心等待...");
            }
            checkForMissingLineTerminator(command);
        } catch (Exception e) {
            String message = "Error executing: " + command + ".  Cause: " + e;
            printlnError(message);
            throw new RuntimeSqlException(message, e);
        }
    }

    /** Returns true if an error buffer was configured and something was written to it. */
    private boolean hasRecordedErrors() {
        return msg != null && msg.getBuffer().length() > 0;
    }

    /**
     * @deprecated Since 3.5.4, this method is deprecated. Please close the {@link Connection} outside of this class.
     */
    @Deprecated
    public void closeConnection() {
        try {
            connection.close();
        } catch (Exception e) {
            // ignore
        }
    }

    /**
     * Applies the configured auto-commit mode to the connection if it differs.
     */
    private void setAutoCommit() {
        try {
            if (autoCommit != connection.getAutoCommit()) {
                connection.setAutoCommit(autoCommit);
            }
        } catch (Throwable t) {
            throw new RuntimeSqlException("Could not set AutoCommit to " + autoCommit + ". Cause: " + t, t);
        }
    }

    /** Commits the current transaction when auto-commit is off. */
    private void commitConnection() {
        try {
            if (!connection.getAutoCommit()) {
                connection.commit();
            }
        } catch (Throwable t) {
            throw new RuntimeSqlException("Could not commit transaction. Cause: " + t, t);
        }
    }

    /** Rolls back the current transaction when auto-commit is off; failures are only logged. */
    private void rollbackConnection() {
        try {
            if (!connection.getAutoCommit()) {
                connection.rollback();
            }
        } catch (Throwable t) {
            // Best effort: log with the stack trace instead of masking the original outcome.
            logger.error("事务回滚出现异常:{}", t.getMessage(), t);
        }
    }

    /** Fails if text remains after the last delimiter, i.e. the final statement was never executed. */
    private void checkForMissingLineTerminator(StringBuilder command) {
        if (command != null && command.toString().trim().length() > 0) {
            throw new RuntimeSqlException("Line missing end-of-line terminator (" + delimiter + ") => " + command);
        }
    }

    /**
     * Processes one script line in line-by-line mode.
     * @param command buffer holding the statement accumulated so far
     * @param line raw line from the script
     * @throws SQLException if executing a completed statement fails and errors are not swallowed
     */
    private void handleLine(StringBuilder command, String line) throws SQLException {
        String trimmedLine = line.trim();
        if (lineIsComment(trimmedLine)) {
            // Comment lines may carry a "@DELIMITER xyz" directive that switches the delimiter.
            Matcher matcher = DELIMITER_PATTERN.matcher(trimmedLine);
            if (matcher.find()) {
                delimiter = matcher.group(5);
            }
            println(trimmedLine);
        } else if (commandReadyToExecute(trimmedLine)) {
            // The line ends the statement: keep only the text before the last delimiter, then execute.
            command.append(line, 0, line.lastIndexOf(delimiter));
            command.append(LINE_SEPARATOR);
            println(command);
            executeStatement(command.toString());

            command.setLength(0);
        } else if (trimmedLine.length() > 0) {
            // No delimiter yet: keep accumulating the statement.
            command.append(line);
            command.append(LINE_SEPARATOR);
        }
    }

    private boolean lineIsComment(String trimmedLine) {
        return trimmedLine.startsWith("//") || trimmedLine.startsWith("--");
    }

    private boolean commandReadyToExecute(String trimmedLine) {
        // issue #561 remove anything after the delimiter
        return !fullLineDelimiter && trimmedLine.contains(delimiter) || fullLineDelimiter && trimmedLine.equals(delimiter);
    }

    /**
     * Executes one SQL string (possibly several statements) and prints every result it produces.
     * SQLWarnings are always rethrown; other SQLExceptions are rethrown only when stopOnError is set,
     * otherwise they are written to the error log writer.
     */
    private void executeStatement(String command) throws SQLException {
        try (Statement statement = connection.createStatement()) {
            statement.setEscapeProcessing(escapeProcessing);
            String sql = command;
            if (removeCRs) {
                sql = sql.replace("\r\n", "\n");
            }
            try {
                boolean hasResults = statement.execute(sql);
                // Walk all results until there is neither a result set nor an update count.
                while (!(!hasResults && statement.getUpdateCount() == -1)) {
                    checkWarnings(statement);
                    printResults(statement, hasResults);
                    hasResults = statement.getMoreResults();
                }
            } catch (SQLWarning e) {
                throw e;
            } catch (SQLException e) {
                if (stopOnError) {
                    throw e;
                } else {
                    String message = "Error executing: " + command + ".  Cause: " + e;
                    printlnError(message);
                }
            }
        }
    }

    private void checkWarnings(Statement statement) throws SQLException {
        if (!throwWarning) {
            return;
        }
        // In Oracle, CREATE PROCEDURE, FUNCTION, etc. returns warning
        // instead of throwing exception if there is compilation error.
        SQLWarning warning = statement.getWarnings();
        if (warning != null) {
            throw warning;
        }
    }

    /** Prints the current result set as tab-separated rows to the log writer. */
    private void printResults(Statement statement, boolean hasResults) {
        if (!hasResults) {
            return;
        }
        try (ResultSet rs = statement.getResultSet()) {
            ResultSetMetaData md = rs.getMetaData();
            int cols = md.getColumnCount();
            for (int i = 0; i < cols; i++) {
                String name = md.getColumnLabel(i + 1);
                print(name + "\t");
            }
            println("");
            while (rs.next()) {
                for (int i = 0; i < cols; i++) {
                    String value = rs.getString(i + 1);
                    print(value + "\t");
                }
                println("");
            }
        } catch (SQLException e) {
            printlnError("Error printing results: " + e.getMessage());
        }
    }

    private void print(Object o) {
        if (logWriter != null) {
            logWriter.print(o);
            logWriter.flush();
        }
    }

    private void println(Object o) {
        if (logWriter != null) {
            logWriter.println(o);
            logWriter.flush();
        }
    }

    private void printlnError(Object o) {
        if (errorLogWriter != null) {
            errorLogWriter.println();
            errorLogWriter.println(o);
            errorLogWriter.flush();
        }
    }
}

2. 调用 MyScriptRunner 批量执行 SQL 文件的工具类

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.datasource.DataSourceUtils;
import org.springframework.stereotype.Component;
import javax.sql.DataSource;
import java.io.*;
import java.sql.Connection;

/**
 * Executes SQL script files against the application's {@link DataSource} using {@link MyScriptRunner}.
 *
 * <p>Scripts are first sent in full-script (batch) mode. If a file fails, its transaction is rolled
 * back and that file plus every later one are replayed statement by statement with stop-on-error
 * disabled. This collects every failing statement and logs it for the user to fix. Files that
 * completed before the failure were already committed by the batch pass and are not replayed.
 *
 * <p>The MySQL JDBC URL must contain {@code allowMultiQueries=true} for batch mode to work.
 * @author 
 * @date 2023年10月10日
 */
@Component
public class ExecuteSqlUtils {

    private static final Logger logger = LoggerFactory.getLogger(ExecuteSqlUtils.class);

    /** Marker passed to {@link #batchSql} for the batch (full-script) pass. */
    private static final int BATCH_PASS = 0;

    /** Marker passed to {@link #batchSql} for the line-by-line replay pass. */
    private static final int LINE_BY_LINE_PASS = 1;

    /**
     * Data source the scripts are executed against.
     */
    private DataSource dataSource;


    public ExecuteSqlUtils() {
    }

    @Autowired
    public ExecuteSqlUtils(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    /**
     * Runs the given SQL script files with {@link MyScriptRunner}. Batch mode is tried first. On
     * failure, execution switches to line-by-line mode, and every erroneous statement is logged.
     *
     * <p>The connection's original auto-commit flag is restored. The connection is then released
     * back to the data source, even if execution fails.
     * @param sqlPath file-system paths of UTF-8 encoded SQL scripts, executed in order
     */
    public void doExecuteSql(String[] sqlPath) {
        Connection connection = DataSourceUtils.getConnection(dataSource);
        try {
            boolean originalAutoCommit = connection.getAutoCommit();
            try {
                executeWithFallback(sqlPath, connection);
            } finally {
                // MyScriptRunner switches auto-commit off; restore it before the connection is reused.
                connection.setAutoCommit(originalAutoCommit);
            }
        } catch (java.sql.SQLException e) {
            logger.error("读取或恢复数据库连接的自动提交状态失败", e);
        } finally {
            DataSourceUtils.releaseConnection(connection, dataSource);
        }
    }

    /** Runs the batch pass and, if it fails, the line-by-line replay from the failing file onward. */
    private void executeWithFallback(String[] sqlPath, Connection connection) {
        MyScriptRunner scriptRunner = new MyScriptRunner(connection);
        // Batch pass: no info/error output, stop at the first error, manual commit, full-script mode.
        scriptRunner.setLogWriter(null);
        scriptRunner.setErrorLogWriter(null);
        scriptRunner.setStopOnError(true);
        scriptRunner.setAutoCommit(false);
        scriptRunner.setSendFullScript(true);

        logger.info("批处理执行中");
        int failedIndex = batchSql(sqlPath, scriptRunner, BATCH_PASS);
        if (failedIndex < 0) {
            logger.info("批处理成功!");
            return;
        }

        logger.info("逐行检索SQL启动");
        // Line-by-line pass: collect every error into errorSql instead of stopping at the first one.
        StringWriter errorSql = new StringWriter();
        scriptRunner.setMsg(errorSql);
        scriptRunner.setErrorLogWriter(new PrintWriter(errorSql));
        scriptRunner.setStopOnError(false);
        scriptRunner.setAutoCommit(false);
        scriptRunner.setSendFullScript(false);

        // Files before failedIndex were committed by the batch pass; replaying them would duplicate data.
        String[] remaining = java.util.Arrays.copyOfRange(sqlPath, failedIndex, sqlPath.length);
        batchSql(remaining, scriptRunner, LINE_BY_LINE_PASS);

        String errorMsg = errorSql.toString();
        if (errorMsg.length() != 0) {
            logger.error("--------------请修改以下错误sql再次执行脚本--------------");
            logger.error("sql错误:【{}】", errorMsg);
        } else {
            // Batch mode failed although every statement succeeds on its own,
            // e.g. a chunk exceeded MySQL's max_allowed_packet.
            logger.info("逐行插入成功!");
        }
    }

    /**
     * Runs each script file in order with the given runner.
     * @param filePaths SQL file paths
     * @param scriptRunner configured script runner
     * @param mark {@code 0} for the batch pass (stop at the first failing file),
     *             {@code 1} for the line-by-line pass (log failures and continue)
     * @return index of the file that failed during the batch pass, or {@code -1} if none failed
     *         (always {@code -1} for the line-by-line pass)
     */
    private int batchSql(String[] filePaths, MyScriptRunner scriptRunner, int mark) {
        for (int i = 0; i < filePaths.length; i++) {
            String path = filePaths[i];
            try (BufferedReader bufferedReader = new BufferedReader(
                    new InputStreamReader(new FileInputStream(path), java.nio.charset.StandardCharsets.UTF_8))) {
                scriptRunner.runScript(bufferedReader);
            } catch (IOException e) {
                logger.error("读取sql文件失败:{}", path, e);
            } catch (Exception e) {
                if (mark == BATCH_PASS) {
                    logger.error("批处理执行失败,异常信息:{}", e.getMessage());
                    return i;
                }
                // Individual statement errors are already collected in the error buffer; this covers
                // failures of the file as a whole (e.g. commit failure or a missing final delimiter).
                logger.error("逐行执行sql文件异常:{}", path, e);
            }
        }
        return -1;
    }


}

数据库连接 URL 需要加入 `&allowMultiQueries=true` 参数,否则 MySQL 驱动不允许在一次 execute 中发送多条 SQL 语句,批处理模式会执行失败。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值