Spark：将 DataFrame 结果保存到 MySQL，表不存在时自动创建

本例在网上已有代码的基础上加以完善：修复了其中的几处 bug，并新增了若干函数。

package edu.xd.spark

import java.sql.{Date, Timestamp}
import java.util.Properties

import org.apache.log4j.Logger
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

object OperatorMySql {

  // Local imports keep this object self-contained on top of the file-level imports.
  import java.sql.{Connection, PreparedStatement, Types}
  import scala.util.control.NonFatal

  val logger: Logger = Logger.getLogger(this.getClass.getSimpleName)

  /** Classpath resource holding the MySQL settings (the same file MySqlPool reads). */
  private val DbPropertiesFile = "db.properties"

  /** Number of rows sent to MySQL per JDBC batch. */
  private val DefaultBatchSize = 1000

  /**
   * Writes every row of the DataFrame into an existing MySQL table using
   * `insert into <table> values(?, ...)` and pooled connections.
   *
   * Values are bound positionally with their native JDBC types, so the DataFrame
   * columns must be in the same order as the table columns. Each partition is
   * written in its own transaction, in batches of `DefaultBatchSize` rows. A failing
   * partition is rolled back and logged; the failure is not propagated (best effort).
   *
   * @param tableName       target table name (used verbatim in the SQL)
   * @param resultDateFrame DataFrame to write
   */
  def saveDFtoDBUsePool(tableName: String, resultDateFrame: DataFrame): Unit = {
    val sql = getInsertSql(tableName, resultDateFrame.columns.length)
    val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
    resultDateFrame.foreachPartition { (partitionRecords: Iterator[Row]) =>
      try {
        writePartition(sql, partitionRecords) { (statement, record) =>
          bindColumns(statement, record, columnDataTypes)
        }
      } catch {
        case NonFatal(e) => logger.error(s"@@ saveDFtoDBUsePool ${e.getMessage}", e)
      }
    }
  }

  /**
   * Builds a positional insert statement: `insert into <table> values(?,?,...)`.
   *
   * @param tableName  table name, used verbatim
   * @param colNumbers number of `?` placeholders (a non-positive count yields `values()`)
   * @return the SQL text
   */
  def getInsertSql(tableName: String, colNumbers: Int): String =
    s"insert into $tableName values(${placeholders(colNumbers)})"

  /**
   * Reads the JDBC url, user name and password from `db.properties`.
   * Uses the keys `mysql.jdbc.url`, `mysql.jdbc.user` and `mysql.jdbc.password`,
   * the same ones MySqlPool reads.
   *
   * @return (jdbcURL, userName, password); an entry is null when its key is missing
   */
  def getMysqlInfo: (String, String, String) = {
    def setting(key: String): String = PropertiyUtils.getFileProperties(DbPropertiesFile, key)
    (setting("mysql.jdbc.url"), setting("mysql.jdbc.user"), setting("mysql.jdbc.password"))
  }

  /**
   * Loads a MySQL table as a DataFrame through Spark's JDBC source.
   *
   * @param sqlContext     SQLContext used to read
   * @param mysqlTableName table name
   * @param queryCondition optional filter expression passed to `where`; null or blank means no filter
   * @return the (optionally filtered) DataFrame
   */
  def getDFFromeMysql(sqlContext: SQLContext, mysqlTableName: String, queryCondition: String = ""): DataFrame = {
    val (jdbcURL, userName, password) = getMysqlInfo
    val prop = new Properties()
    // Properties is a Hashtable and rejects null values, so only set what is configured.
    Option(userName).foreach(prop.setProperty("user", _))
    Option(password).foreach(prop.setProperty("password", _))
    val table = sqlContext.read.jdbc(jdbcURL, mysqlTableName, prop)
    if (queryCondition == null || queryCondition.trim.isEmpty) table else table.where(queryCondition)
  }

  /**
   * Drops a table.
   *
   * The name is interpolated into the SQL verbatim, so never pass untrusted input.
   *
   * @param mysqlTableName table to drop
   * @return true when the statement succeeded, false (after logging) on any error
   */
  def dropMysqlTable(mysqlTableName: String): Boolean =
    try {
      withConnection(conn => executeSql(conn, s"drop table $mysqlTableName"))
      true
    } catch {
      case NonFatal(e) =>
        logger.error(s"mysql drop MysqlTable error:${e.getMessage}", e)
        false
    }

  /**
   * Deletes rows matching a raw SQL condition.
   *
   * The table name and condition are interpolated verbatim, so never pass untrusted input.
   *
   * @param SQLContext     unused; kept for source compatibility
   * @param mysqlTableName table name
   * @param condition      everything that follows `where`
   * @return true when the statement succeeded, false (after logging) on any error
   */
  def deleteMysqlTableData(SQLContext: SQLContext, mysqlTableName: String, condition: String): Boolean =
    try {
      withConnection(conn => executeSql(conn, s"delete from $mysqlTableName where $condition"))
      true
    } catch {
      case NonFatal(e) =>
        logger.error(s"mysql delete MysqlTableNameData error:${e.getMessage}", e)
        false
    }

  /**
   * Saves the DataFrame into MySQL, creating the table first when it does not exist.
   *
   * Then checks that the table's columns match the DataFrame's column count, names and
   * order before inserting.
   *
   * @param tableName       target table
   * @param resultDataFrame DataFrame to save
   */
  def saveDFtoDBCreateTableIfNotExists(tableName: String, resultDataFrame: DataFrame): Unit = {
    createTableIfNotExist(tableName, resultDataFrame)
    verifyFieldConsistency(tableName, resultDataFrame)
    saveDFtoDBUsePool(tableName, resultDataFrame)
  }

  /**
   * Creates the table when missing, then upserts every row.
   *
   * The upsert uses `insert ... on duplicate key update`, updating all columns.
   * Partition failures are logged as a warning.
   *
   * @param tableName target table
   * @param df        DataFrame to upsert
   */
  def saveOrUpdateAndCreateTableIfNotExits(tableName: String, df: DataFrame): Unit = {
    createTableIfNotExist(tableName, df)
    val succeeded = insertOrUpdateDFtoDBUserPoolWithMysqlSaveUpdateNew(tableName, df, df.columns)
    if (!succeeded) {
      logger.warn(s"saveOrUpdateAndCreateTableIfNotExits: some partitions failed to write into $tableName")
    }
  }

  /**
   * Creates `tableName` from the DataFrame schema when no column of that table is visible.
   *
   * A column named `id` (any case) becomes an auto-increment primary key. Other columns
   * are mapped to nullable MySQL types.
   *
   * @param tableName table to create
   * @param df        DataFrame whose schema defines the columns
   * @return true if a CREATE TABLE statement was issued, false if the table already existed
   * @throws RuntimeException         for a column type with no MySQL mapping
   * @throws IllegalArgumentException when the DataFrame has no columns
   */
  def createTableIfNotExist(tableName: String, df: DataFrame): Boolean = withConnection { conn =>
    if (tableColumnNames(conn, tableName).nonEmpty) {
      false
    } else {
      val fields = df.schema.fields
      require(fields.nonEmpty, s"cannot create table $tableName from a DataFrame without columns")
      val sqlCreateTable = fields.map(columnDefinition).mkString(
        s"CREATE TABLE if not exists  ${quoteIdentifier(tableName)} (",
        ",",
        ") engine = InnoDB default charset=utf8mb4")
      logger.info(sqlCreateTable)
      executeSql(conn, sqlCreateTable)
      true
    }
  }

  /**
   * Builds an upsert that copies the inserted values on conflict.
   *
   * The SQL has the form `insert into t(c1,..) values(?,..) on duplicate key update u=values(u),..`.
   *
   * @param tableName     table name, used verbatim
   * @param cols          inserted columns, in bind order
   * @param updateColumns columns overwritten on a duplicate key; must not be empty
   * @return the SQL text
   */
  def getInsertOrUpdateSqlNew(tableName: String, cols: Array[String], updateColumns: Array[String]): String = {
    require(updateColumns.nonEmpty, "updateColumns must not be empty")
    val assignments = updateColumns.map(quoteIdentifier).map(c => s"$c=values($c)").mkString(",")
    s"${insertWithColumns(tableName, cols)} on duplicate key update $assignments"
  }

  /**
   * Builds an upsert whose update values are bound as extra parameters.
   *
   * The SQL has the form `insert into t(c1,..) values(?,..) on duplicate key update u=?,..`.
   * The update placeholders follow the insert placeholders.
   *
   * @param tableName     table name, used verbatim
   * @param cols          inserted columns, in bind order
   * @param updateColumns columns overwritten on a duplicate key; must not be empty
   * @return the SQL text
   */
  def getInsertOrUpdateSql(tableName: String, cols: Array[String], updateColumns: Array[String]): String = {
    require(updateColumns.nonEmpty, "updateColumns must not be empty")
    val assignments = updateColumns.map(quoteIdentifier).map(c => s"$c=?").mkString(",")
    s"${insertWithColumns(tableName, cols)} on duplicate key update $assignments"
  }

  /**
   * Upserts the DataFrame with `getInsertOrUpdateSqlNew` (`col=values(col)` updates).
   *
   * Each partition is one transaction, written in batches of `DefaultBatchSize`. A failing
   * partition is rolled back and logged, and other partitions continue.
   *
   * @param tableName       target table
   * @param resultDateFrame DataFrame to write
   * @param updateColumns   columns to overwrite on a duplicate key
   * @return true when every partition was written successfully
   */
  def insertOrUpdateDFtoDBUserPoolWithMysqlSaveUpdateNew(tableName: String, resultDateFrame: DataFrame, updateColumns: Array[String]): Boolean = {
    val sql = getInsertOrUpdateSqlNew(tableName, resultDateFrame.columns, updateColumns)
    val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
    logger.info(s"\n$sql")
    // Executors cannot update driver-side vars; an accumulator carries failures back.
    val failedPartitions = failureCounter(resultDateFrame, tableName)
    resultDateFrame.foreachPartition { (partitionRecords: Iterator[Row]) =>
      try {
        writePartition(sql, partitionRecords) { (statement, record) =>
          bindColumns(statement, record, columnDataTypes)
        }
      } catch {
        case NonFatal(e) =>
          logger.error(s"@@  ${e.getMessage}", e)
          failedPartitions.add(1L)
      }
    }
    failedPartitions.sum == 0L
  }

  /**
   * Upserts the DataFrame with `getInsertOrUpdateSql` (`col=?` updates).
   *
   * The update values are taken from the same row's `updateColumns`. Each partition is
   * one transaction, written in batches of `DefaultBatchSize`. A failing partition is
   * rolled back and logged.
   *
   * @param tableName       target table
   * @param resultDateFrame DataFrame to write
   * @param updateColumns   columns to overwrite on a duplicate key; each must exist in the DataFrame
   * @return true when every partition was written successfully
   */
  def insertOrUpdateDFtoDBUserPool(tableName: String, resultDateFrame: DataFrame, updateColumns: Array[String]): Boolean = {
    val columnCount = resultDateFrame.columns.length
    val sql = getInsertOrUpdateSql(tableName, resultDateFrame.columns, updateColumns)
    val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
    // Resolve update columns once on the driver; an unknown name fails fast here.
    val updateFieldIndices = updateColumns.map(name => resultDateFrame.schema.fieldIndex(name))
    logger.info(s"\n$sql")
    val failedPartitions = failureCounter(resultDateFrame, tableName)
    resultDateFrame.foreachPartition { (partitionRecords: Iterator[Row]) =>
      try {
        writePartition(sql, partitionRecords) { (statement, record) =>
          bindColumns(statement, record, columnDataTypes)
          // The `col=?` placeholders come right after the insert placeholders.
          updateFieldIndices.zipWithIndex.foreach { case (fieldIndex, offset) =>
            bindValue(statement, columnCount + offset + 1, record, fieldIndex, columnDataTypes(fieldIndex))
          }
        }
      } catch {
        case NonFatal(e) =>
          logger.error(s"@@  ${e.getMessage}", e)
          failedPartitions.add(1L)
      }
    }
    failedPartitions.sum == 0L
  }

  /**
   * Checks that the MySQL table has the same number of columns as the DataFrame.
   * Also checks that the names match position by position (case-insensitively, as MySQL
   * column names are).
   *
   * @throws Exception describing the first mismatch
   */
  def verifyFieldConsistency(tableName: String, df: DataFrame): Unit = withConnection { conn =>
    val tableFields = tableColumnNames(conn, tableName)
    val dfFields = df.columns
    val tableFieldNum = tableFields.length
    val dfFieldNum = dfFields.length
    if (tableFieldNum != dfFieldNum) {
      throw new Exception(s"mysql表列数量:${tableFieldNum}与DataFrame数量:${dfFieldNum}不一致")
    }
    tableFields.zip(dfFields).foreach { case (tableFieldName, dfFieldName) =>
      if (!tableFieldName.equalsIgnoreCase(dfFieldName)) {
        throw new Exception(s"mysql表列名${tableFieldName}与DataFrame列名${dfFieldName}不一致")
      }
    }
  }

  // ---------------------------------------------------------------------------
  // Private helpers
  // ---------------------------------------------------------------------------

  /** Comma-separated list of `count` JDBC placeholders, e.g. `?,?,?`. */
  private def placeholders(count: Int): String = Seq.fill(count)("?").mkString(",")

  /** Backtick-quotes a MySQL identifier so reserved words and odd characters are safe. */
  private def quoteIdentifier(name: String): String = "`" + name.replace("`", "``") + "`"

  /** `insert into <table>(`c1`,...) values(?,...)` shared by both upsert builders. */
  private def insertWithColumns(tableName: String, cols: Array[String]): String =
    s"insert into $tableName(${cols.map(quoteIdentifier).mkString(",")}) values(${placeholders(cols.length)})"

  /** MySQL column definition for one DataFrame field, used by createTableIfNotExist. */
  private def columnDefinition(field: StructField): String = {
    val name = quoteIdentifier(field.name)
    if (field.name.equalsIgnoreCase("id")) {
      // `id` becomes the auto-increment primary key; Long ids need bigint to avoid overflow.
      val idType = field.dataType match {
        case _: LongType => "bigint(20)"
        case _ => "int(255)"
      }
      s"$name $idType not null auto_increment primary key"
    } else {
      val columnType = field.dataType match {
        case _: ByteType => "int(10) default null"
        case _: ShortType => "int(10) default null"
        case _: IntegerType => "int(20) default null"
        case _: LongType => "bigint(20) default null"
        case _: BooleanType => "tinyint default null"
        case _: FloatType => "float default null"
        case _: DoubleType => "double default null"
        case _: StringType => "varchar(64) default null"
        case _: TimestampType => "timestamp default current_timestamp"
        case _: DateType => "date default null"
        case d: DecimalType => s"decimal(${d.precision},${d.scale}) default null"
        case other => throw new RuntimeException(s"non support $other!!!")
      }
      s"$name $columnType"
    }
  }

  /** Binds every column of `record` to parameters 1..n of `statement`. */
  private def bindColumns(statement: PreparedStatement, record: Row, columnDataTypes: Array[DataType]): Unit =
    columnDataTypes.indices.foreach(i => bindValue(statement, i + 1, record, i, columnDataTypes(i)))

  /** Binds `record(fieldIndex)` to JDBC parameter `paramIndex` using its native type. */
  private def bindValue(statement: PreparedStatement, paramIndex: Int, record: Row, fieldIndex: Int, dataType: DataType): Unit =
    if (record.isNullAt(fieldIndex)) {
      statement.setNull(paramIndex, sqlTypeOf(dataType))
    } else {
      dataType match {
        // Rows hold Byte/Short boxes for these types, so read them with the matching getter.
        case _: ByteType => statement.setByte(paramIndex, record.getByte(fieldIndex))
        case _: ShortType => statement.setShort(paramIndex, record.getShort(fieldIndex))
        case _: IntegerType => statement.setInt(paramIndex, record.getInt(fieldIndex))
        case _: LongType => statement.setLong(paramIndex, record.getLong(fieldIndex))
        case _: BooleanType => statement.setBoolean(paramIndex, record.getBoolean(fieldIndex))
        case _: FloatType => statement.setFloat(paramIndex, record.getFloat(fieldIndex))
        case _: DoubleType => statement.setDouble(paramIndex, record.getDouble(fieldIndex))
        case _: StringType => statement.setString(paramIndex, record.getString(fieldIndex))
        case _: TimestampType => statement.setTimestamp(paramIndex, record.getTimestamp(fieldIndex))
        case _: DateType => statement.setDate(paramIndex, record.getDate(fieldIndex))
        case _: DecimalType => statement.setBigDecimal(paramIndex, record.getDecimal(fieldIndex))
        case other => throw new RuntimeException(s"nonsupport $other !!!")
      }
    }

  /** java.sql.Types code passed to setNull for a Spark type (Types.NULL when unmapped). */
  private def sqlTypeOf(dataType: DataType): Int = dataType match {
    case _: ByteType => Types.TINYINT
    case _: ShortType => Types.SMALLINT
    case _: IntegerType => Types.INTEGER
    case _: LongType => Types.BIGINT
    case _: BooleanType => Types.BOOLEAN
    case _: FloatType => Types.REAL
    case _: DoubleType => Types.DOUBLE
    case _: StringType => Types.VARCHAR
    case _: TimestampType => Types.TIMESTAMP
    case _: DateType => Types.DATE
    case _: DecimalType => Types.DECIMAL
    case _ => Types.NULL
  }

  /** Driver-side counter of partitions that failed to write. */
  private def failureCounter(df: DataFrame, tableName: String) =
    df.sqlContext.sparkContext.longAccumulator(s"failed partitions writing $tableName")

  /** Borrows a pooled connection, failing clearly if the pool hands back null. */
  private def borrowConnection(): Connection = {
    val conn = MySqlPoolManager.getMysqlManager.getConnection
    if (conn == null) throw new IllegalStateException("could not obtain a MySQL connection from the pool")
    conn
  }

  /** Runs `body` with a pooled connection and always closes (returns) it afterwards. */
  private def withConnection[T](body: Connection => T): T = {
    val conn = borrowConnection()
    try body(conn) finally conn.close()
  }

  /**
   * Executes a single DDL/DML statement and closes the statement.
   * Commits explicitly if the connection is not in auto-commit mode, in case the pool
   * hands one out that way.
   */
  private def executeSql(conn: Connection, sql: String): Unit = {
    val statement = conn.createStatement()
    try {
      statement.executeUpdate(sql)
      if (!conn.getAutoCommit) conn.commit()
    } finally statement.close()
  }

  /**
   * Writes one partition in a single transaction.
   *
   * For every record, `bind` fills the prepared statement, which is then added to the
   * batch; the batch is flushed every `DefaultBatchSize` rows and once more at the end.
   * On error the transaction is rolled back and the exception rethrown. The statement is
   * always closed, auto-commit is restored, and the connection is returned to the pool.
   */
  private def writePartition(sql: String, records: Iterator[Row])(bind: (PreparedStatement, Row) => Unit): Unit =
    withConnection { conn =>
      val autoCommitBefore = conn.getAutoCommit
      conn.setAutoCommit(false)
      try {
        val statement = conn.prepareStatement(sql)
        try {
          var pending = 0
          records.foreach { record =>
            bind(statement, record)
            statement.addBatch()
            pending += 1
            if (pending == DefaultBatchSize) {
              statement.executeBatch()
              pending = 0
            }
          }
          if (pending > 0) statement.executeBatch()
          conn.commit()
        } catch {
          case NonFatal(e) =>
            try conn.rollback() catch { case NonFatal(rollbackError) => e.addSuppressed(rollbackError) }
            throw e
        } finally statement.close()
      } finally conn.setAutoCommit(autoCommitBefore)
    }

  /** Column names of `tableName` in ordinal order, scoped to the connection's current database. */
  private def tableColumnNames(conn: Connection, tableName: String): Vector[String] = {
    // A null catalog means "do not narrow by catalog" in JDBC, which can match same-named
    // tables in other databases; use the current one instead.
    // NOTE(review): tableName is a LIKE pattern here, so `_` and `%` act as wildcards.
    val columns = conn.getMetaData.getColumns(conn.getCatalog, "%", tableName, "%")
    try {
      Iterator.continually(columns.next()).takeWhile(identity).map(_ => columns.getString("COLUMN_NAME")).toVector
    } finally columns.close()
  }
}

package edu.xd.spark

import java.io.Serializable
import java.sql.Connection

import com.mchange.v2.c3p0.ComboPooledDataSource

class MySqlPool extends Serializable {

  private val cpds: ComboPooledDataSource = new ComboPooledDataSource(true)

  // Connection settings from db.properties. A failure is printed and leaves the pool
  // partially configured, matching the previous best-effort behaviour.
  try {
    cpds.setJdbcUrl(setting("mysql.jdbc.url"))
    cpds.setDriverClass(setting("mysql.driver"))
    cpds.setUser(setting("mysql.jdbc.user"))
    cpds.setPassword(setting("mysql.jdbc.password"))
  } catch {
    case e: Exception => e.printStackTrace()
  }

  // Pool sizing is applied key by key. A missing key keeps c3p0's default, and a malformed
  // one is reported; neither prevents the remaining settings from being applied.
  applyIntSetting("mysql.pool.jdbc.minPoolSize")(cpds.setMinPoolSize(_))
  applyIntSetting("mysql.pool.jdbc.maxPoolSize")(cpds.setMaxPoolSize(_))
  applyIntSetting("mysql.pool.jdbc.acquireIncrement")(cpds.setAcquireIncrement(_))
  applyIntSetting("mysql.pool.jdbc.maxStatements")(cpds.setMaxStatements(_))

  /**
   * Borrows a connection from the pool.
   *
   * @return a pooled connection, or null if the pool could not provide one (the error is printed)
   */
  def getConnection: Connection = {
    try {
      cpds.getConnection()
    } catch {
      case ex: Exception =>
        ex.printStackTrace()
        null
    }
  }

  /** Closes the underlying c3p0 data source; errors are printed, not thrown. */
  def close(): Unit = {
    try {
      cpds.close()
    } catch {
      case ex: Exception =>
        ex.printStackTrace()
    }
  }

  /** Reads one key from db.properties on the classpath (null when absent). */
  private def setting(key: String): String =
    PropertiyUtils.getFileProperties("db.properties", key)

  /** Applies an optional integer setting when present and non-blank; errors are printed. */
  private def applyIntSetting(key: String)(setter: Int => Unit): Unit =
    try {
      Option(setting(key)).map(_.trim).filter(_.nonEmpty).foreach(value => setter(value.toInt))
    } catch {
      case e: Exception => e.printStackTrace()
    }
}

package edu.xd.spark

import java.util.Properties

object PropertiyUtils {

  /**
   * Loads a properties file from the classpath and returns one value.
   *
   * The resource stream is always closed, and the file is re-read on every call.
   *
   * @param fileName      classpath resource name, e.g. "db.properties"
   * @param propertityKey key to look up
   * @return the value, or null when the key is absent
   * @throws IllegalArgumentException when the resource is not on the classpath
   */
  def getFileProperties(fileName: String, propertityKey: String): String = {
    val stream = this.getClass.getClassLoader.getResourceAsStream(fileName)
    if (stream == null) {
      throw new IllegalArgumentException(s"properties resource not found on classpath: $fileName")
    }
    try {
      val prop = new Properties()
      prop.load(stream)
      prop.getProperty(propertityKey)
    } finally {
      stream.close()
    }
  }

  /**
   * Shortcut for reading a key from "db.properties".
   *
   * @param key key to look up
   * @return the value, or null when the key is absent
   */
  def getDbProperties(key: String): String = getFileProperties("db.properties", key)
}

调用 saveOrUpdateAndCreateTableIfNotExits(tableName: String, df: DataFrame) 方法时，若表不存在会先自动建表，然后以 insert ... on duplicate key update 的方式插入或更新数据。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值