DBUtil工具类，适用于Oracle数据库

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

import javax.annotation.Resource;
import javax.sql.DataSource;
import java.sql.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * JDBC utility class for Oracle databases.
 *
 * <p>Connection settings are read once from the property map returned by
 * {@code PropertyUtil.getPropertyMap(Const.DB_CONFIG)} (keys {@code jdbc.db.url}, {@code jdbc.db.driver},
 * {@code jdbc.db.username}, {@code jdbc.db.password}) and cached in a static map. Only the {@code oracle}
 * database type is supported; for other types {@link #getConn} and {@link #getTables(Connection)} return
 * {@code null}.
 *
 * <p>Connection ownership: every method that receives a {@link Connection} closes it before returning,
 * except {@link #getTableColumns(Connection, String)} and {@link #getTableColumnsByTableName(Connection, String)},
 * which leave it open for the caller.
 *
 * <p>SECURITY: SQL text and table names passed to this class are concatenated into statements verbatim
 * (table names cannot be bound as JDBC parameters). Never pass untrusted user input to these methods.
 */
@Slf4j
public class DBUtil {

    private final static String TYPE_ORACLE = "oracle";

    /** Prefix of every database key in the property file. */
    private final static String DB = "jdbc.db.";

    private final static String DB_USERNAME = "username";

    private final static String DB_PWD = "password";

    private final static String DB_DRIVER = "driver";

    private final static String DB_URL = "url";

    private final static String DB_IP = "ip";

    private final static String DB_PORT = "port";

    private final static String DB_TYPE = "type";

    private final static String DB_DBNAME = "dbName";

    /** Extracts the sub-protocol (database type) from a JDBC url, e.g. "oracle". */
    private final static String REGEX_TYPE = "jdbc:(.*?):";

    /** Splits an Oracle thin url of the form {@code jdbc:oracle:thin:@host:port:sid}. */
    private final static String REGEX_ORACLE = "jdbc:oracle:thin:@(.*?):(.*):(.*)";

    /** Pre-compiled {@link #REGEX_TYPE}; Pattern instances are immutable and thread-safe. */
    private final static Pattern PATTERN_TYPE = Pattern.compile(REGEX_TYPE);

    /** Pre-compiled {@link #REGEX_ORACLE}. */
    private final static Pattern PATTERN_ORACLE = Pattern.compile(REGEX_ORACLE);

    private final static String CONN_ORACLE = "jdbc:oracle:thin:@";

    private final static String SELECT_FROM = "select * from ";

    /** Cached configuration; filled lazily by {@link #getDBConfig()} (guarded by the class lock). */
    private final static Map<String, String> dbConfig = new HashMap<String, String>();

    /** Database type parsed from {@code jdbc.db.url}; written by {@link #getDBConfig()}. */
    private static String dbType;

    /**
     * Injected data source. The JDBC helpers in this class do not use it; they open their own
     * connections through {@link DriverManager}.
     */
    @Resource
    private DataSource dataSource;

    /**
     * Returns the cached database configuration, loading it from the property file on first use.
     *
     * <p>The map always contains {@code type}. For Oracle it also contains {@code driver},
     * {@code username}, {@code password}, and, when the url matches
     * {@code jdbc:oracle:thin:@host:port:sid}, {@code ip}, {@code port} and {@code dbName}.
     *
     * @return the shared configuration map
     * @throws IllegalStateException if {@code jdbc.db.url} is not configured (nothing is cached, so a
     *                               later call retries)
     */
    public static synchronized Map<String, String> getDBConfig() {
        if (!dbConfig.isEmpty()) {
            return dbConfig;
        }
        Map<String, String> propConfig = PropertyUtil.getPropertyMap(Const.DB_CONFIG);
        String propUrl = propConfig == null ? null : propConfig.get(DB + DB_URL);
        if (propUrl == null) {
            throw new IllegalStateException("Missing database property: " + DB + DB_URL);
        }

        Matcher typeMatcher = PATTERN_TYPE.matcher(propUrl);
        dbType = typeMatcher.find() ? typeMatcher.group(1) : null;
        dbConfig.put(DB_TYPE, dbType);

        if (TYPE_ORACLE.equals(dbType)) {
            dbConfig.put(DB_DRIVER, propConfig.get(DB + DB_DRIVER));
            dbConfig.put(DB_USERNAME, propConfig.get(DB + DB_USERNAME));
            dbConfig.put(DB_PWD, propConfig.get(DB + DB_PWD));
            Matcher oracleMatcher = PATTERN_ORACLE.matcher(propUrl);
            if (oracleMatcher.find()) {
                dbConfig.put(DB_IP, oracleMatcher.group(1));
                dbConfig.put(DB_PORT, oracleMatcher.group(2));
                dbConfig.put(DB_DBNAME, oracleMatcher.group(3));
            } else {
                log.warn("Oracle url does not match {}: {}", REGEX_ORACLE, propUrl);
            }
        }
        return dbConfig;
    }

    /**
     * Opens a new connection using the configuration from {@link #getDBConfig()}.
     *
     * @return a new connection, or {@code null} if the configured type is not Oracle
     * @throws Exception if the driver cannot be loaded or the connection fails
     */
    public static Connection getDBConn() throws Exception {
        Map<String, String> config = getDBConfig();
        return getConn(
                config.get(DB_DRIVER),
                config.get(DB_TYPE),
                config.get(DB_USERNAME),
                config.get(DB_PWD),
                config.get(DB_IP),
                config.get(DB_PORT),
                config.get(DB_DBNAME)
        );
    }

    /**
     * Opens a connection to {@code jdbc:oracle:thin:@ip:port:databaseName}.
     *
     * @param driver       JDBC driver class name, loaded via {@link Class#forName(String)}
     * @param dbType       database type; only {@code "oracle"} is supported
     * @param username     database user
     * @param password     database password
     * @param ip           host
     * @param port         port
     * @param databaseName SID appended after the port
     * @return a new connection, or {@code null} if {@code dbType} is not Oracle
     * @throws Exception if the driver cannot be loaded or the connection fails
     */
    public static Connection getConn(String driver, String dbType, String username, String password, String ip,
                                     String port, String databaseName) throws Exception {
        if (!TYPE_ORACLE.equals(dbType)) {
            log.warn("Unsupported database type: {}", dbType);
            return null;
        }
        Class.forName(driver);
        String oracleDataSourceUrl = CONN_ORACLE + ip + ":" + port + ":" + databaseName;
        return DriverManager.getConnection(oracleDataSourceUrl, username, password);
    }

    /** Returns the configured database type, loading the configuration if necessary. */
    private static String configuredDbType() {
        return getDBConfig().get(DB_TYPE);
    }

    /**
     * Lists table names using a newly opened connection (closed before returning).
     *
     * @return table names, or {@code null} for unsupported types or on error
     */
    public List<String> getTables() throws Exception {
        return getTables(getDBConn());
    }

    /**
     * Lists table names on {@code conn}, then closes it.
     *
     * @param conn connection to query; always closed by this method
     * @return table names, or {@code null} for unsupported types or on error
     */
    public List<String> getTables(Connection conn) throws Exception {
        String type;
        try {
            // Load the configuration first: the static dbType is unset until getDBConfig() has run.
            type = configuredDbType();
        } catch (RuntimeException e) {
            closeConn(conn);
            throw e;
        }
        if (TYPE_ORACLE.equals(type)) {
            return getTablesOracle(conn);
        }
        closeConn(conn);
        return null;
    }

    /**
     * Lists the names of all objects of type {@code TABLE} reported by the driver's metadata.
     * No catalog or schema filter is applied. The connection is always closed.
     *
     * @param conn connection to inspect
     * @return table names, or {@code null} if an error occurred (the error is logged)
     */
    public List<String> getTablesOracle(Connection conn) {
        try {
            List<String> tableList = new ArrayList<String>();
            DatabaseMetaData meta = conn.getMetaData();
            try (ResultSet rs = meta.getTables(null, null, null, new String[]{"TABLE"})) {
                while (rs.next()) {
                    // Column 3 of DatabaseMetaData.getTables is TABLE_NAME.
                    tableList.add(rs.getString(3));
                }
            }
            return tableList;
        } catch (Exception e) {
            log.error("getTablesOracle方法错误:", e);
            return null;
        } finally {
            closeConn(conn);
        }
    }

    /**
     * Describes the columns of {@code tableName} using a newly opened connection (closed before returning).
     *
     * @param tableName table name, concatenated into SQL verbatim (must be trusted)
     */
    public List<TColumn> getTableColumnsByTableName(String tableName) throws Exception {
        Connection conn = getDBConn();
        try {
            return getTableColumnsByTableName(conn, tableName);
        } finally {
            closeConn(conn);
        }
    }

    /**
     * Describes the columns of {@code tableName}. The connection is left open.
     *
     * @param tableName table name, concatenated into SQL verbatim (must be trusted)
     */
    public List<TColumn> getTableColumnsByTableName(Connection conn, String tableName) throws Exception {
        return getTableColumns(conn, SELECT_FROM + tableName);
    }

    /**
     * Describes the result columns of {@code sqlStr} using a newly opened connection (closed before returning).
     */
    public List<TColumn> getTableColumns(String sqlStr) throws Exception {
        Connection conn = getDBConn();
        try {
            return getTableColumns(conn, sqlStr);
        } finally {
            closeConn(conn);
        }
    }

    /**
     * Describes the result columns of {@code sqlStr} without fetching data.
     * The connection is left open; the internal statement is closed.
     *
     * @param conn   connection to use
     * @param sqlStr query whose columns are described (wrapped as a sub-query)
     * @return one {@link TColumn} per result column, in order
     */
    public static List<TColumn> getTableColumns(Connection conn, String sqlStr) throws Exception {
        // The always-false predicate "0!=0" returns no rows, so only metadata is produced.
        String sql = SELECT_FROM + "(" + sqlStr + ") tcolumns where 0!=0";
        try (PreparedStatement pstmt = conn.prepareStatement(sql)) {
            pstmt.execute();
            ResultSetMetaData rsmd = pstmt.getMetaData();
            int columnCount = rsmd.getColumnCount();
            List<TColumn> columns = new ArrayList<TColumn>(columnCount);
            for (int i = 1; i <= columnCount; i++) {
                columns.add(new TColumn(rsmd.getColumnName(i), rsmd.getColumnTypeName(i), rsmd.getPrecision(i),
                        rsmd.getScale(i), rsmd.isNullable(i)));
            }
            return columns;
        }
    }

    /**
     * Reads every row of {@code tableName} using a newly opened connection.
     *
     * @param tableName table name, concatenated into SQL verbatim (must be trusted)
     * @return header row of column names followed by the data rows
     */
    public List<List<Object>> queryByTableName(String tableName) throws Exception {
        return queryByTableName(getDBConn(), tableName);
    }

    /**
     * Reads every row of {@code tableName} on {@code conn}, then closes it.
     *
     * @param tableName table name, concatenated into SQL verbatim (must be trusted)
     */
    public List<List<Object>> queryByTableName(Connection conn, String tableName) throws Exception {
        return query(conn, SELECT_FROM + tableName);
    }

    /** Runs {@code sqlStr} on a newly opened connection; see {@link #query(Connection, String)}. */
    public List<List<Object>> query(String sqlStr) throws Exception {
        return query(getDBConn(), sqlStr);
    }

    /**
     * Runs the query {@code sqlStr} on {@code conn} and closes the connection.
     *
     * @return a header row of column names followed by one list per data row. On error the failure
     * is logged and whatever was read so far is returned.
     */
    public List<List<Object>> query(Connection conn, String sqlStr) throws Exception {
        List<List<Object>> dataList = new ArrayList<List<Object>>();
        Statement stmt = null;
        try {
            stmt = conn.createStatement();
            try (ResultSet rs = stmt.executeQuery(sqlStr)) {
                readResult(conn, sqlStr, rs, dataList);
            }
        } catch (Exception e) {
            log.error("query方法错误:", e);
        } finally {
            closeStmt(stmt);
            closeConn(conn);
        }
        return dataList;
    }

    /**
     * Reads one page of {@code tableName} on {@code conn}, then closes it.
     *
     * @param tableName table name, concatenated into SQL verbatim (must be trusted)
     */
    public Page<List<Object>> queryByTableName(Connection conn, String tableName, Page<List<Object>> page)
            throws Exception {
        return query(conn, SELECT_FROM + tableName, page);
    }

    /**
     * Reads one page of {@code tableName} using a newly opened connection.
     *
     * @param tableName table name, concatenated into SQL verbatim (must be trusted)
     */
    public Page<List<Object>> queryByTableName(String tableName, Page<List<Object>> page) throws Exception {
        return query(getDBConn(), SELECT_FROM + tableName, page);
    }

    /** Runs a paged query on a newly opened connection; see {@link #query(Connection, String, Page)}. */
    public Page<List<Object>> query(String sqlStr, Page<List<Object>> page) throws Exception {
        return query(getDBConn(), sqlStr, page);
    }

    /**
     * Runs {@code sqlStr} as a paged query on {@code conn}, then closes the connection.
     *
     * <p>Sets {@code page}'s total record count from a {@code count(*)} sub-query. Its results become a
     * header row of column names followed by the rows of the requested page. Errors are logged and
     * {@code page} is returned as far as it was filled.
     */
    public Page<List<Object>> query(Connection conn, String sqlStr, Page<List<Object>> page) throws Exception {
        List<List<Object>> dataList = new ArrayList<List<Object>>();
        Statement stmt = null;
        try {
            stmt = conn.createStatement();

            Integer total = queryCount(stmt, sqlStr);
            if (total != null) {
                page.setTotalRecord(total);
            }

            String type = configuredDbType();
            if (!TYPE_ORACLE.equals(type)) {
                throw new SQLException("Unsupported database type for paging: " + type);
            }
            String sqlPage = getOraclePageSql(page, new StringBuffer(sqlStr));

            try (ResultSet rs = stmt.executeQuery(sqlPage)) {
                readResult(conn, sqlStr, rs, dataList);
            }
            page.setResults(dataList);
        } catch (Exception e) {
            log.error("query方法错误:", e);
        } finally {
            closeStmt(stmt);
            closeConn(conn);
        }
        return page;
    }

    /** Executes an update/DDL statement on a newly opened connection; see {@link #operate(Connection, String)}. */
    public int operate(String sqlStr) throws Exception {
        return operate(getDBConn(), sqlStr);
    }

    /**
     * Executes {@code sqlStr} via {@link Statement#executeUpdate(String)} and closes the connection.
     *
     * @return the affected row count, or {@code 0} on error (the error is logged)
     */
    public int operate(Connection conn, String sqlStr) throws Exception {
        int res = 0;
        Statement stmt = null;
        try {
            stmt = conn.createStatement();
            res = stmt.executeUpdate(sqlStr);
        } catch (Exception e) {
            log.error("operate方法错误:", e);
        } finally {
            closeStmt(stmt);
            closeConn(conn);
        }
        return res;
    }

    /**
     * Appends a header row (the column names of {@code columnSql}) and then every row of {@code rs}
     * to {@code dataList}.
     *
     * <p>Only the first {@code columns.size()} columns of each row are read. Extra trailing columns,
     * such as the {@code r} row number added by {@link #getOraclePageSql}, are therefore ignored.
     */
    private static void readResult(Connection conn, String columnSql, ResultSet rs,
                                   List<List<Object>> dataList) throws Exception {
        List<TColumn> columns = getTableColumns(conn, columnSql);
        int columnCount = columns.size();

        List<Object> columnList = new ArrayList<Object>(columnCount);
        for (TColumn tc : columns) {
            columnList.add(tc.getName());
        }
        dataList.add(columnList);

        while (rs.next()) {
            List<Object> row = new ArrayList<Object>(columnCount);
            for (int i = 1; i <= columnCount; i++) {
                row.add(rs.getObject(i));
            }
            dataList.add(row);
        }
    }

    /**
     * Runs {@code select count(*)} over {@code sqlStr} on {@code stmt}.
     *
     * @return the count, or {@code null} if the count query returned no row
     */
    private static Integer queryCount(Statement stmt, String sqlStr) throws SQLException {
        try (ResultSet rs = stmt.executeQuery(getCountSql(sqlStr))) {
            return rs.next() ? rs.getInt(1) : null;
        }
    }

    /**
     * Wraps {@code sqlBuffer} (mutated in place) in Oracle ROWNUM pagination.
     *
     * <p>The result has the form
     * {@code select * from (select u.*, rownum r from (<sql>) u where rownum < end) where r >= start},
     * where {@code start = (pageNum - 1) * pageSize + 1} and {@code end = start + pageSize}. It therefore
     * selects the 1-based rows {@code start .. end - 1}, each carrying an extra trailing column {@code r}.
     */
    private static String getOraclePageSql(Page<?> page, StringBuffer sqlBuffer) {
        // ROWNUM is 1-based, so the first row of the page is at offset + 0.
        int offset = (page.getPageNum() - 1) * page.getPageSize() + 1;
        sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ")
                .append(offset + page.getPageSize());
        sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);
        return sqlBuffer.toString();
    }

    /** Wraps {@code sql} in a {@code select count(*)} sub-query. */
    private static String getCountSql(String sql) {
        return "select count(*) from (" + sql + ")  countRecord";
    }

    /** Closes {@code conn} if non-null; failures are logged, not thrown. */
    public static void closeConn(Connection conn) {
        try {
            if (conn != null) {
                conn.close();
            }
        } catch (Exception e) {
            log.error("closeConn方法错误:", e);
        }
    }

    /**
     * Executes arbitrary SQL, paging it if it is a query.
     *
     * <p>Result keys:
     * <ul>
     *   <li>{@code res}: "1" on success, "0" on failure or blank SQL</li>
     *   <li>{@code resMsg}: failure message</li>
     *   <li>{@code executeSQL}: the submitted SQL</li>
     *   <li>{@code type}: "query" or "operate"</li>
     *   <li>{@code list}: {@code page}, for queries</li>
     *   <li>{@code rTime}: milliseconds until the statement returned</li>
     *   <li>{@code lines}: affected row count, for updates</li>
     * </ul>
     *
     * @param sqlStr SQL statement (executed verbatim — must be trusted)
     * @param page   paging request; receives the total count and the page rows for queries
     * @return result map as described above
     */
    public static Map<String, Object> executeSQL(String sqlStr, Page<List<Object>> page) {
        Map<String, Object> result = new HashMap<String, Object>(10);
        Statement stmt = null;
        Connection conn = null;
        int lines = 0;
        // Request start time in milliseconds.
        long startTime = System.currentTimeMillis();
        long rTime = 0;
        try {
            if (StringUtils.isNotBlank(sqlStr)) {
                result.put("executeSQL", sqlStr);
                conn = getDBConn();
                stmt = conn.createStatement();
                if (isQuerySql(sqlStr)) {
                    Integer total = queryCount(stmt, sqlStr);
                    if (total != null) {
                        page.setTotalRecord(total);
                    }
                    String sqlPage = getOraclePageSql(page, new StringBuffer(sqlStr));
                    List<List<Object>> dataList = new ArrayList<List<Object>>();
                    try (ResultSet rs = stmt.executeQuery(sqlPage)) {
                        rTime = System.currentTimeMillis() - startTime;
                        readResult(conn, sqlStr, rs, dataList);
                    }
                    page.setResults(dataList);
                    result.put("list", page);
                    result.put("type", "query");
                } else {
                    lines = stmt.executeUpdate(sqlStr);
                    rTime = System.currentTimeMillis() - startTime;
                    result.put("type", "operate");
                }
                // Success is reported as "1".
                result.put("res", "1");
            } else {
                // Failure is reported as "0".
                result.put("res", "0");
                result.put("resMsg", "sql语句不能为空");
            }
        } catch (Exception e) {
            result.put("res", "0");
            result.put("resMsg", e.getMessage());
            log.error("executeSQL方法错误:", e);
        } finally {
            closeStmt(stmt);
            closeConn(conn);
        }
        result.put("rTime", rTime);
        result.put("lines", lines);
        return result;
    }

    /**
     * Returns whether {@code sqlStr}, ignoring leading/trailing whitespace, starts with
     * {@code select} (case-insensitive, locale-independent).
     *
     * @return {@code false} for {@code null}
     */
    public static boolean isQuerySql(String sqlStr) {
        return sqlStr != null && sqlStr.trim().regionMatches(true, 0, "select", 0, 6);
    }

    /** Closes {@code stmt} (and any open ResultSet) if non-null; failures are logged, not thrown. */
    public static void closeStmt(Statement stmt) {
        try {
            if (stmt != null) {
                stmt.close();
            }
        } catch (Exception e) {
            log.error("closeStmt方法错误:", e);
        }
    }

    /**
     * Executes arbitrary SQL without paging.
     *
     * <p>The result map has the same keys as {@link #executeSQL(String, Page)}, except:
     * <ul>
     *   <li>{@code list} is the header row plus all data rows;</li>
     *   <li>queries also add {@code count}, the {@code count(*)} of the query.</li>
     * </ul>
     *
     * @param sqlStr SQL statement (executed verbatim — must be trusted)
     * @return result map
     */
    public static Map<String, Object> executeSQL(String sqlStr) {
        Map<String, Object> result = new HashMap<String, Object>();
        Statement stmt = null;
        Connection conn = null;
        int lines = 0;
        // Request start time in milliseconds.
        long startTime = System.currentTimeMillis();
        long rTime = 0;
        try {
            if (StringUtils.isNotBlank(sqlStr)) {
                result.put("executeSQL", sqlStr);
                conn = getDBConn();
                stmt = conn.createStatement();
                if (isQuerySql(sqlStr)) {
                    Integer count = queryCount(stmt, sqlStr);
                    if (count != null) {
                        result.put("count", count);
                    }
                    List<List<Object>> dataList = new ArrayList<List<Object>>();
                    try (ResultSet rs = stmt.executeQuery(sqlStr)) {
                        rTime = System.currentTimeMillis() - startTime;
                        readResult(conn, sqlStr, rs, dataList);
                    }
                    result.put("list", dataList);
                    result.put("type", "query");
                } else {
                    lines = stmt.executeUpdate(sqlStr);
                    rTime = System.currentTimeMillis() - startTime;
                    result.put("type", "operate");
                }
                // Success is reported as "1".
                result.put("res", "1");
            } else {
                // Failure is reported as "0".
                result.put("res", "0");
                result.put("resMsg", "sql语句不能为空");
            }
        } catch (Exception e) {
            result.put("res", "0");
            result.put("resMsg", e.getMessage());
            log.error("executeSQL方法错误:", e);
        } finally {
            closeStmt(stmt);
            closeConn(conn);
        }
        result.put("rTime", rTime);
        result.put("lines", lines);
        return result;
    }

    public DataSource getDataSource() {
        return dataSource;
    }

    public void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    /**
     * Manual smoke test; not part of the utility API.
     *
     * <p>Step 1 parses the database type from a sample url. Step 2 runs a paged query through
     * {@link #executeSQL(String, Page)} and therefore needs a configured, reachable database.
     */
    public static void main(String[] args) {
        // 1. Derive the database type from a sample JDBC url.
        String propUrl = "jdbc:oracle:thin:@127.0.0.1:1521:helowin";
        Matcher m = PATTERN_TYPE.matcher(propUrl);
        while (m.find()) {
            dbType = m.group(1);
            System.out.println(dbType);
        }

        // 2. Run a paged query.
        Page<List<Object>> page = new Page<>();
        page.setPageNum(1);
        page.setPageSize(1);
        Map<String, Object> map = executeSQL(
                "select * from test11",
                page);
        System.out.println(map);
    }


}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
`dbutils`是Databricks提供的一个工具类，可以在Databricks平台上方便地操作文件系统（DBFS）、密钥、notebook等资源；真正通过JDBC读取数据源的是Spark的`spark.read.jdbc`，`dbutils`只在上传驱动等辅助步骤中用到。在Databricks上访问JDBC数据源时，您可以按照以下步骤进行操作：

1. 首先，您需要为集群准备JDBC驱动程序。可以先使用以下代码把驱动JAR复制到DBFS：

```python
dbutils.fs.cp("file:/path/to/jdbc_driver.jar", "dbfs:/mnt/jdbc_driver/jdbc_driver.jar")
```

   注意：复制到DBFS并不会自动安装驱动，还需要在集群的“Libraries”页面中将该JAR安装为集群库（如果Databricks Runtime已自带该驱动，可以省略这一步）。

2. 然后，您可以使用以下代码来创建JDBC连接并读取数据：

```python
jdbcHostname = "your_jdbc_hostname"
jdbcDatabase = "your_jdbc_database_name"
jdbcPort = 1433
jdbcUsername = "your_jdbc_username"
jdbcPassword = "your_jdbc_password"
jdbcUrl = "jdbc:sqlserver://{0}:{1};database={2}".format(jdbcHostname, jdbcPort, jdbcDatabase)
connectionProperties = {
    "user": jdbcUsername,
    "password": jdbcPassword,
    "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver"
}
jdbcDF = spark.read.jdbc(url=jdbcUrl, table="your_jdbc_table_name", properties=connectionProperties)
```

请将上述代码中的"your_jdbc_hostname"、"your_jdbc_database_name"、"your_jdbc_username"、"your_jdbc_password"和"your_jdbc_table_name"替换为您自己的JDBC连接参数和需要读取的数据表名称。示例使用的是SQL Server的URL格式和驱动类"com.microsoft.sqlserver.jdbc.SQLServerDriver"。连接其他数据库（例如Oracle，URL形如`jdbc:oracle:thin:@host:1521:sid`）时，需要同时替换URL格式和驱动类名。

3. 最后，您可以使用`jdbcDF`变量来访问JDBC数据源的数据。例如，可以使用以下代码显示数据表中的内容：

```python
display(jdbcDF)
```

希望这些代码可以帮助您在Databricks上访问JDBC数据源。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值