ScriptRunner 是 MyBatis 中用于执行 SQL 脚本的类。这个类也经常在 Java 代码中直接用来执行 SQL。
首先来看下 MyBatis 中的测试用例。
/**
 * Creates a {@link PooledDataSource} configured from a properties resource.
 *
 * <p>Reads the keys {@code driver}, {@code url}, {@code username} and {@code password}
 * from the loaded properties and passes each value to the matching setter.
 *
 * @param resource name of the properties resource, resolved through {@code Resources}
 * @return the newly created and configured data source
 * @throws IOException declared by the resource lookup
 */
public static PooledDataSource createPooledDataSource(String resource) throws IOException {
    final Properties settings = Resources.getResourceAsProperties(resource);
    final PooledDataSource pool = new PooledDataSource();
    pool.setDriver(settings.getProperty("driver"));
    pool.setUrl(settings.getProperty("url"));
    pool.setUsername(settings.getProperty("username"));
    pool.setPassword(settings.getProperty("password"));
    return pool;
}
/**
 * Runs the SQL script found at {@code resource} on a connection obtained from {@code ds}.
 *
 * <p>The runner is configured with auto-commit on, stop-on-error off, and {@code null}
 * log writers (presumably to silence script output; the setters are not shown here).
 *
 * <p>The connection is managed by try-with-resources, so it is always closed and a failure
 * while closing is attached as a suppressed exception instead of replacing the exception
 * thrown by the script itself.
 *
 * @param ds       data source that supplies the connection
 * @param resource name of the script resource to execute
 * @throws IOException  if the script resource cannot be opened or closed
 * @throws SQLException if a connection cannot be obtained or closed
 */
public static void runScript(DataSource ds, String resource) throws IOException, SQLException {
    try (Connection connection = ds.getConnection()) {
        ScriptRunner runner = new ScriptRunner(connection);
        runner.setAutoCommit(true);
        runner.setStopOnError(false);
        runner.setLogWriter(null);
        runner.setErrorLogWriter(null);
        runScript(runner, resource);
    }
}
/**
 * Opens the script at {@code resource} and feeds it to the given runner.
 *
 * <p>The reader is managed by try-with-resources, so it is always closed and a failure
 * while closing is attached as a suppressed exception instead of hiding an exception
 * thrown while the script was running.
 *
 * @param runner   runner that executes the script
 * @param resource name of the script resource, resolved through {@code Resources}
 * @throws IOException  if the resource cannot be opened or closed
 * @throws SQLException kept in the signature for callers; not thrown by the visible code
 */
public static void runScript(ScriptRunner runner, String resource) throws IOException, SQLException {
    try (Reader reader = Resources.getResourceAsReader(resource)) {
        runner.runScript(reader);
    }
}
下面来看下ScriptRunner的源码解析:
package org.apache.ibatis.jdbc;
import java.io.BufferedReader;
import java.io.PrintWriter;
import java.io.Reader;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.Statement;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* @author Clinton Begin
*/
// Runs SQL scripts against a JDBC connection
public class ScriptRunner {
// Line separator appended after each script line; falls back to "\n" when the system property is unset
private static final String LINE_SEPARATOR = System.getProperty("line.separator", "\n");
// Default statement delimiter
private static final String DEFAULT_DELIMITER = ";";
// Matches an "@DELIMITER <token>" directive, optionally prefixed by "--" or "//" (case-insensitive);
// capture group 5 holds the new delimiter token
private static final Pattern DELIMITER_PATTERN = Pattern.compile("^\\s*((--)|(//))?\\s*(//)?\\s*@DELIMITER\\s+([^\\s]+)", Pattern.CASE_INSENSITIVE);
// JDBC connection the script is executed on
private final Connection connection;
// Whether a failing statement is rethrown (true) or only logged to the error writer (false)
private boolean stopOnError;
// Whether a SQLWarning reported by a statement is thrown
private boolean throwWarning;
// Auto-commit mode applied to the connection before the script runs
private boolean autoCommit;
// Whether the whole script is executed as a single statement instead of line by line
private boolean sendFullScript;
// Whether "\r\n" is replaced with "\n" in the SQL before execution
private boolean removeCRs;
// Value passed to Statement.setEscapeProcessing; enabled by default
private boolean escapeProcessing = true;
// Writers for regular output and error output; a null writer disables that output
private PrintWriter logWriter = new PrintWriter(System.out);
private PrintWriter errorLogWriter = new PrintWriter(System.err);
// Current statement delimiter; may be replaced by an @DELIMITER directive found in a comment line
private String delimiter = DEFAULT_DELIMITER;
// When true, a command ends only on a line that equals the delimiter; otherwise on any line containing it
private boolean fullLineDelimiter;
/**
 * Creates a runner that executes scripts on the given connection.
 *
 * @param connection JDBC connection used for every statement; stored as-is, not validated here
 */
public ScriptRunner(Connection connection) {
this.connection = connection;
}
/**
 * Sets whether a failing statement stops the run.
 *
 * @param stopOnError {@code true} to rethrow a statement's SQLException,
 *                    {@code false} to log it to the error writer and continue
 */
public void setStopOnError(boolean stopOnError) {
this.stopOnError = stopOnError;
}
// The remaining setters are omitted here
/**
 * Executes the script supplied by {@code reader}.
 *
 * <p>The connection's auto-commit mode is aligned with this runner's setting first. The
 * script is then run either as one statement or line by line, depending on
 * {@code sendFullScript}. {@code rollbackConnection()} is always invoked afterwards,
 * whether or not execution succeeded.
 *
 * @param reader source of the script text
 */
public void runScript(Reader reader) {
    // Apply the configured transaction mode before touching the script
    setAutoCommit();
    try {
        if (!sendFullScript) {
            // Split on the delimiter and run statement by statement
            executeLineByLine(reader);
        } else {
            // Send the entire script as a single statement
            executeFullScript(reader);
        }
    } finally {
        // Always attempt a rollback once execution is over
        rollbackConnection();
    }
}
/**
 * Reads the whole script and executes it as a single statement.
 *
 * <p>Every line is appended with {@code LINE_SEPARATOR}, the assembled text is logged,
 * executed and then committed. Any exception is logged to the error writer and rethrown
 * wrapped in a {@code RuntimeSqlException}.
 *
 * @param reader source of the script text
 */
private void executeFullScript(Reader reader) {
    StringBuilder script = new StringBuilder();
    try {
        BufferedReader input = new BufferedReader(reader);
        for (String row = input.readLine(); row != null; row = input.readLine()) {
            script.append(row).append(LINE_SEPARATOR);
        }
        String fullText = script.toString();
        // Log the script before running it
        println(fullText);
        executeStatement(fullText);
        commitConnection();
    } catch (Exception e) {
        String message = "Error executing: " + script + ". Cause: " + e;
        printlnError(message);
        throw new RuntimeSqlException(message, e);
    }
}
/**
 * Reads the script one line at a time, executing each command as soon as it is complete.
 *
 * <p>After the last line the transaction is committed, and only then is the buffer checked
 * for a trailing command that never reached its delimiter. Any exception is logged to the
 * error writer and rethrown wrapped in a {@code RuntimeSqlException}.
 *
 * @param reader source of the script text
 */
private void executeLineByLine(Reader reader) {
    StringBuilder command = new StringBuilder();
    try {
        BufferedReader input = new BufferedReader(reader);
        for (String row = input.readLine(); row != null; row = input.readLine()) {
            // Accumulates the line, or executes the command it completes
            handleLine(command, row);
        }
        commitConnection();
        // Leftover text means the final command had no delimiter
        checkForMissingLineTerminator(command);
    } catch (Exception e) {
        String message = "Error executing: " + command + ". Cause: " + e;
        printlnError(message);
        throw new RuntimeSqlException(message, e);
    }
}
/**
 * Closes the underlying connection on a best-effort basis.
 *
 * <p>Any exception raised while closing is deliberately discarded.
 */
public void closeConnection() {
    try {
        connection.close();
    } catch (Exception ignored) {
        // best effort: a failure to close is not reported
    }
}
/**
 * Brings the connection's auto-commit mode in line with the {@code autoCommit} field.
 *
 * <p>The mode is only changed when it differs from the current one. Any failure is
 * rethrown wrapped in a {@code RuntimeSqlException}.
 */
private void setAutoCommit() {
    try {
        final boolean current = connection.getAutoCommit();
        if (current != autoCommit) {
            connection.setAutoCommit(autoCommit);
        }
    } catch (Throwable t) {
        throw new RuntimeSqlException("Could not set AutoCommit to " + autoCommit + ". Cause: " + t, t);
    }
}
/**
 * Commits the current transaction unless the connection is in auto-commit mode.
 *
 * <p>Any failure is rethrown wrapped in a {@code RuntimeSqlException}.
 */
private void commitConnection() {
    try {
        final boolean manualCommit = !connection.getAutoCommit();
        if (manualCommit) {
            connection.commit();
        }
    } catch (Throwable t) {
        throw new RuntimeSqlException("Could not commit transaction. Cause: " + t, t);
    }
}
/**
 * Rolls back the current transaction unless the connection is in auto-commit mode.
 *
 * <p>Any failure is deliberately discarded.
 */
private void rollbackConnection() {
    try {
        final boolean manualCommit = !connection.getAutoCommit();
        if (manualCommit) {
            connection.rollback();
        }
    } catch (Throwable t) {
        // ignore
    }
}
/**
 * Fails when the command buffer still holds non-blank text, meaning the last command
 * in the script was never terminated by the delimiter.
 *
 * @param command buffer of the command currently being assembled; may be {@code null}
 */
private void checkForMissingLineTerminator(StringBuilder command) {
    if (command == null) {
        return;
    }
    final String pending = command.toString().trim();
    if (pending.isEmpty()) {
        return;
    }
    throw new RuntimeSqlException("Line missing end-of-line terminator (" + delimiter + ") => " + command);
}
/**
 * Processes a single script line.
 *
 * <ul>
 *   <li>Comment line: logged; if it carries an {@code @DELIMITER} directive, the current
 *       delimiter is replaced by the directive's token.</li>
 *   <li>Line that completes a command: the text before the last occurrence of the delimiter
 *       is appended, the command is logged and executed, and the buffer is cleared.</li>
 *   <li>Any other non-blank line: appended to the buffer to be continued later.</li>
 * </ul>
 *
 * @param command buffer holding the command assembled so far
 * @param line    raw line read from the script
 * @throws SQLException if executing a completed command fails
 */
private void handleLine(StringBuilder command, String line) throws SQLException {
    final String stripped = line.trim();

    if (lineIsComment(stripped)) {
        final Matcher directive = DELIMITER_PATTERN.matcher(stripped);
        if (directive.find()) {
            delimiter = directive.group(5);
        }
        println(stripped);
        return;
    }

    if (commandReadyToExecute(stripped)) {
        // Keep only what precedes the last delimiter on this line
        final int cut = line.lastIndexOf(delimiter);
        command.append(line.substring(0, cut));
        command.append(LINE_SEPARATOR);
        println(command);
        executeStatement(command.toString());
        command.setLength(0);
        return;
    }

    if (stripped.length() > 0) {
        // Not the end of a command yet: keep accumulating
        command.append(line);
        command.append(LINE_SEPARATOR);
    }
}
/**
 * Tells whether a line is a comment, i.e. starts with {@code //} or {@code --}.
 *
 * @param trimmedLine the line with surrounding whitespace already removed
 * @return {@code true} if the line starts with one of the two comment markers
 */
private boolean lineIsComment(String trimmedLine) {
    if (trimmedLine.startsWith("//")) {
        return true;
    }
    return trimmedLine.startsWith("--");
}
/**
 * Tells whether a line terminates the command being assembled.
 *
 * <p>In full-line mode the line must equal the delimiter exactly; otherwise it is enough
 * for the line to contain the delimiter anywhere.
 *
 * @param trimmedLine the line with surrounding whitespace already removed
 * @return {@code true} if the accumulated command should be executed now
 */
private boolean commandReadyToExecute(String trimmedLine) {
    // issue #561 remove anything after the delimiter
    return fullLineDelimiter
            ? trimmedLine.equals(delimiter)
            : trimmedLine.contains(delimiter);
}
/**
 * Executes one SQL command on a newly created {@link Statement} and prints every result
 * it produces.
 *
 * <p>A {@link SQLWarning} is always rethrown. Any other {@link SQLException} raised during
 * execution is rethrown only when {@code stopOnError} is set; otherwise it is written to
 * the error writer and swallowed. The statement is closed in all cases, and a failure
 * while closing it is ignored.
 *
 * @param command the SQL text to execute
 * @throws SQLException if the statement cannot be created or configured, if a warning is
 *                      raised, or if execution fails while {@code stopOnError} is set
 */
private void executeStatement(String command) throws SQLException {
Statement statement = connection.createStatement();
try {
// Apply the configured escape-processing flag to this statement
statement.setEscapeProcessing(escapeProcessing);
String sql = command;
// Optionally replace "\r\n" with "\n" before sending the SQL
if (removeCRs) {
sql = sql.replaceAll("\r\n", "\n");
}
try {
// Execute the SQL
boolean hasResults = statement.execute(sql);
// Loop while a result set or an update count is still available:
// (!hasResults && getUpdateCount() == -1) is the JDBC "no more results" condition.
// NOTE(review): the getUpdateCount() call is part of the condition; do not simplify it away.
while (!(!hasResults && statement.getUpdateCount() == -1)) {
// Throws the statement's warning, if any, when throwWarning is enabled
checkWarnings(statement);
// Print the current result set (no-op when the current result is an update count)
printResults(statement, hasResults);
hasResults = statement.getMoreResults();
}
} catch (SQLWarning e) {
// Warnings are propagated regardless of stopOnError
throw e;
} catch (SQLException e) {
if (stopOnError) {
throw e;
} else {
// Continue with the rest of the script; only report the failure
String message = "Error executing: " + command + ". Cause: " + e;
printlnError(message);
}
}
} finally {
try {
statement.close();
} catch (Exception e) {
// Ignore to workaround a bug in some connection pools
// (Does anyone know the details of the bug?)
}
}
}
/**
 * Throws the statement's first pending warning when {@code throwWarning} is enabled.
 *
 * @param statement statement whose warnings are inspected
 * @throws SQLException the pending {@link SQLWarning}, or a failure while reading warnings
 */
private void checkWarnings(Statement statement) throws SQLException {
    if (throwWarning) {
        // In Oracle, CREATE PROCEDURE, FUNCTION, etc. returns warning
        // instead of throwing exception if there is compilation error.
        final SQLWarning pending = statement.getWarnings();
        if (pending != null) {
            throw pending;
        }
    }
}
/**
 * Prints the statement's current result set as tab-separated text: one header row of
 * column labels followed by one row per record.
 *
 * <p>Does nothing when {@code hasResults} is {@code false}. A {@link SQLException} raised
 * while reading the results is reported to the error writer and not propagated.
 *
 * @param statement  statement holding the current result
 * @param hasResults whether the current result is a result set
 */
private void printResults(Statement statement, boolean hasResults) {
    if (hasResults) {
        try (ResultSet rs = statement.getResultSet()) {
            final ResultSetMetaData meta = rs.getMetaData();
            final int columnCount = meta.getColumnCount();
            // Header row (JDBC column indexes are 1-based)
            for (int col = 1; col <= columnCount; col++) {
                print(meta.getColumnLabel(col) + "\t");
            }
            println("");
            // Data rows
            while (rs.next()) {
                for (int col = 1; col <= columnCount; col++) {
                    print(rs.getString(col) + "\t");
                }
                println("");
            }
        } catch (SQLException e) {
            printlnError("Error printing results: " + e.getMessage());
        }
    }
}
/**
 * Writes {@code o} to the log writer without a line break and flushes it.
 * Does nothing when no log writer is set.
 *
 * @param o value to print
 */
private void print(Object o) {
    if (logWriter == null) {
        return;
    }
    logWriter.print(o);
    logWriter.flush();
}
/**
 * Writes {@code o} followed by a line break to the log writer and flushes it.
 * Does nothing when no log writer is set.
 *
 * @param o value to print
 */
private void println(Object o) {
    if (logWriter == null) {
        return;
    }
    logWriter.println(o);
    logWriter.flush();
}
/**
 * Writes {@code o} followed by a line break to the error log writer and flushes it.
 * Does nothing when no error log writer is set.
 *
 * @param o value to print
 */
private void printlnError(Object o) {
    if (errorLogWriter == null) {
        return;
    }
    errorLogWriter.println(o);
    errorLogWriter.flush();
}
}