Spark upsert writes to MySQL (Scala enhancement, no extra dependencies)

The idea: add an update() method to DataFrameWriter through an implicit class, pull the writer's private state out via reflection, and re-plan the write through a custom JdbcRelationProvider that issues INSERT ... ON DUPLICATE KEY UPDATE instead of a plain INSERT.

import java.sql.Connection
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite, JdbcRelationProvider, JdbcUtils}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SQLContext, SaveMode}
object DataFrameWriterEnhance {

  implicit class DataFrameWriterMysqlUpdateEnhance(writer: DataFrameWriter[Row]) {
    def update(): Unit = {
      // DataFrameWriter keeps its configuration in private fields, so extract it via
      // reflection and re-plan the write through our own relation provider.
      val extraOptionsField = writer.getClass.getDeclaredField("org$apache$spark$sql$DataFrameWriter$$extraOptions")
      val dfField = writer.getClass.getDeclaredField("df")
      val sourceField = writer.getClass.getDeclaredField("source")
      val partitioningColumnsField = writer.getClass.getDeclaredField("partitioningColumns")
      extraOptionsField.setAccessible(true)
      dfField.setAccessible(true)
      sourceField.setAccessible(true)
      partitioningColumnsField.setAccessible(true)
      val extraOptions = extraOptionsField.get(writer).asInstanceOf[scala.collection.mutable.HashMap[String, String]]
      val df = dfField.get(writer).asInstanceOf[DataFrame]
      val partitioningColumns = partitioningColumnsField.get(writer).asInstanceOf[Option[Seq[String]]]
      val logicalPlanField = df.getClass.getDeclaredField("logicalPlan")
      logicalPlanField.setAccessible(true)
      var logicalPlan = logicalPlanField.get(df).asInstanceOf[LogicalPlan]
      val session = df.sparkSession
      // Plan the write against the custom upsert provider rather than the built-in "jdbc" source.
      val dataSource = DataSource(
        sparkSession = session,
        className = "com.xxx.xxx.enhance.mysql.DataFrameWriterEnhance$MysqlUpdateRelationProvider",
        partitionColumns = partitioningColumns.getOrElse(Nil),
        options = extraOptions.toMap)
      logicalPlan = dataSource.planForWriting(SaveMode.Append, logicalPlan)
      val qe = session.sessionState.executePlan(logicalPlan)
      SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
    }
  }

  // Mirrors the write path of Spark's built-in JdbcRelationProvider, except that the
  // actual data write goes through updateTable (upsert) instead of JdbcUtils.saveTable.
  class MysqlUpdateRelationProvider extends JdbcRelationProvider {
    override def createRelation(sqlContext: SQLContext, mode: SaveMode,
                                parameters: Map[String, String], df: DataFrame): BaseRelation = {
      val options = new JdbcOptionsInWrite(parameters)
      val isCaseSensitive = sqlContext.sparkSession.sessionState.conf.caseSensitiveAnalysis
      val conn = JdbcUtils.createConnectionFactory(options)()
      try {
        val tableExists = JdbcUtils.tableExists(conn, options)
        if (tableExists) {
          mode match {
            case SaveMode.Overwrite =>
              if (options.isTruncate && JdbcUtils.isCascadingTruncateTable(options.url) == Some(false)) {
                // In this case, we should truncate table and then load.
                JdbcUtils.truncateTable(conn, options)
                val tableSchema = JdbcUtils.getSchemaOption(conn, options)
                updateTable(df, tableSchema, isCaseSensitive, options)
              } else {
                // Otherwise, do not truncate the table, instead drop and recreate it
                JdbcUtils.dropTable(conn, options.table, options)
                JdbcUtils.createTable(conn, df, options)
                updateTable(df, Some(df.schema), isCaseSensitive, options)
              }

            case SaveMode.Append =>
              val tableSchema = JdbcUtils.getSchemaOption(conn, options)
              updateTable(df, tableSchema, isCaseSensitive, options)

            case SaveMode.ErrorIfExists =>
              throw new Exception(
                s"Table or view '${options.table}' already exists. " +
                  s"SaveMode: ErrorIfExists.")

            case SaveMode.Ignore =>
            // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected
            // to not save the contents of the DataFrame and to not change the existing data.
            // Therefore, it is okay to do nothing here and then just return the relation below.
          }
        } else {
          JdbcUtils.createTable(conn, df, options)
          updateTable(df, Some(df.schema), isCaseSensitive, options)
        }
      } finally {
        conn.close()
      }

      // Return a read relation for the table via the inherited (read-side) createRelation.
      createRelation(sqlContext, parameters)
    }

    def updateTable(df: DataFrame,
                    tableSchema: Option[StructType],
                    isCaseSensitive: Boolean,
                    options: JdbcOptionsInWrite): Unit = {
      val url = options.url
      val table = options.table
      val dialect = JdbcDialects.get(url)
      val rddSchema = df.schema
      val getConnection: () => Connection = JdbcUtils.createConnectionFactory(options)
      val batchSize = options.batchSize
      val isolationLevel = options.isolationLevel

      val updateStmt = getUpdateStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
      println(updateStmt) // print the generated upsert statement for inspection
      val repartitionedDF = options.numPartitions match {
        case Some(n) if n <= 0 => throw new IllegalArgumentException(
          s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
            "via JDBC. The minimum value is 1.")
        case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
        case _ => df
      }
      repartitionedDF.rdd.foreachPartition(iterator => JdbcUtils.savePartition(
        getConnection, table, iterator, rddSchema, updateStmt, batchSize, dialect, isolationLevel,
        options)
      )
    }

    // Builds an INSERT ... ON DUPLICATE KEY UPDATE statement; this behaves as an upsert
    // only if the target MySQL table has a PRIMARY KEY or UNIQUE index.
    def getUpdateStatement(table: String,
                           rddSchema: StructType,
                           tableSchema: Option[StructType],
                           isCaseSensitive: Boolean,
                           dialect: JdbcDialect): String = {
      val columns = if (tableSchema.isEmpty) {
        rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
      } else {
        val columnNameEquality = if (isCaseSensitive) {
          org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
        } else {
          org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
        }
        // The generated insert statement needs to follow rddSchema's column sequence and
        // tableSchema's column names. When appending data into some case-sensitive DBMSs like
        // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
        // RDD column names for user convenience.
        val tableColumnNames = tableSchema.get.fieldNames
        rddSchema.fields.map { col =>
          val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
            throw new Exception(s"""Column "${col.name}" not found in schema $tableSchema""")
          }
          dialect.quoteIdentifier(normalizedName)
        }.mkString(",")
      }
      val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
      s"""INSERT INTO $table ($columns) VALUES ($placeholders)
         |ON DUPLICATE KEY UPDATE
         |${columns.split(",").map(col => s"$col=VALUES($col)").mkString(",")}
         |""".stripMargin
    }
  }
}

Usage

import com.xxx.xxx.enhance.mysql.DataFrameWriterEnhance.DataFrameWriterMysqlUpdateEnhance
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{NullType, ShortType}
import org.apache.spark.sql.{DataFrame, SaveMode}

  def upsert(rawDF: DataFrame, tableName: String = "xxx"): Unit = {
    var df = rawDF
    // Columns whose type was inferred as NullType cannot be written over JDBC,
    // so cast them to a concrete type (ShortType here) first.
    for (elem <- df.schema.fields) {
      if (elem.dataType == NullType) {
        df = df.withColumn(elem.name, col(elem.name).cast(ShortType))
      }
    }
    df.write
      .format("jdbc")
      .mode(SaveMode.Append)
      .option("driver", "com.mysql.jdbc.Driver")
      .option("url", "")
      .option("user", "")
      .option("password", "")
      .option("dbtable", tableName)
      .option("useSSL","false")
      .option("showSql", "true")
      .update()
  }

The code has been simplified and consolidated, and the private-field access is handled via reflection. Verified on Spark 2.4.x with Scala 2.11.
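
For reference, here is a minimal sketch (with a hypothetical table and column names) of the statement that getUpdateStatement generates; note that ON DUPLICATE KEY UPDATE only acts as an upsert when the target MySQL table has a PRIMARY KEY or UNIQUE index:

// Minimal sketch, assuming the DataFrameWriterEnhance object above is on the classpath.
// "demo_table" and the id/name columns are made up for illustration.
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object UpsertStatementPreview {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))
    val provider = new DataFrameWriterEnhance.MysqlUpdateRelationProvider
    val stmt = provider.getUpdateStatement(
      table = "demo_table",
      rddSchema = schema,
      tableSchema = Some(schema),
      isCaseSensitive = false,
      dialect = JdbcDialects.get("jdbc:mysql://localhost:3306/test"))
    println(stmt)
    // Prints roughly:
    // INSERT INTO demo_table (`id`,`name`) VALUES (?,?)
    // ON DUPLICATE KEY UPDATE
    // `id`=VALUES(`id`),`name`=VALUES(`name`)
  }
}

The ? placeholders are bound later by JdbcUtils.savePartition, exactly as in a normal Append write.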
