Spark3通过Upsert方式写入MySQL

object HiveToMysqlSingleTable {
  private val logger: Logger = LoggerFactory.getLogger(HiveToMysqlSingleTable.getClass)

  /** Fully-qualified class name of the relation provider that performs the MySQL upsert. */
  private val UpsertProviderClassName = "com.oppo.ads.scala.utils.MysqlUpdateRelationProvider"

  /**
   * Runs a Hive SQL query and upserts its result into a MySQL table.
   *
   * Arguments are read through `ArgsParseUtil`:
   *  - `sql`: Hive query whose result is written (required)
   *  - `mysql_ip`, `mysql_port`, `mysql_db`: MySQL connection target
   *  - `mysql_user`, `mysql_pwd`: MySQL credentials
   *  - `mysqlTableName`: target MySQL table (required)
   *  - `numPartitions`: number of partitions the result is written with (default 100)
   *  - `batchSize`: JDBC batch insert size (default "1000")
   *
   * @throws IllegalArgumentException if `sql` or `mysqlTableName` is empty
   */
  def main(args: Array[String]): Unit = {
    logger.info("HiveToMysqlSingleTable program begin...")
    // Parse the command-line arguments.
    ArgsParseUtil.init(args)
    val sql = ArgsParseUtil.getStringValue("sql", "")
    val ip = ArgsParseUtil.getStringValue("mysql_ip", "")
    val port = ArgsParseUtil.getStringValue("mysql_port", "")
    val db = ArgsParseUtil.getStringValue("mysql_db", "")
    val user = ArgsParseUtil.getStringValue("mysql_user", "")
    val password = ArgsParseUtil.getStringValue("mysql_pwd", "")
    val tableName = ArgsParseUtil.getStringValue("mysqlTableName", "")
    val numPartitions = ArgsParseUtil.getIntValue("numPartitions", 100)
    val batchSize = ArgsParseUtil.getStringValue("batchSize", "1000")

    // Fail fast, before a Spark session is started, when mandatory arguments are missing.
    require(sql.nonEmpty, "argument 'sql' must not be empty")
    require(tableName.nonEmpty, "argument 'mysqlTableName' must not be empty")

    val jdbcUrl = s"jdbc:mysql://$ip:$port/$db?useUnicode=true&characterEncoding=utf-8"

    // The password is deliberately masked: logs must never contain credentials.
    logger.info(
      s"""
         |sql==>${sql}
         |ip==>${ip}
         |port==>${port}
         |db==>${db}
         |user==>${user}
         |password==>******
         |tableName==>${tableName}
         |numPartitions==>${numPartitions}
         |batchSize==>${batchSize}
         |""".stripMargin)

    // Hive tables are only reachable when Hive support is enabled.
    val session = SparkSession.builder().enableHiveSupport().getOrCreate()
    try {
      val importDF: DataFrame = session.sql(sql)
      // Write the Hive result into MySQL. The partition count (default 100) should be tuned to
      // the data volume; each partition is written by its own task.
      importDF
        .repartition(numPartitions)
        .write.format("jdbc")
        .mode(SaveMode.Append)
        .options(Map(
          JDBCOptions.JDBC_URL -> jdbcUrl,
          JDBCOptions.JDBC_TABLE_NAME -> tableName,
          "user" -> user,
          "password" -> password,
          // NOTE(review): `showSql` is not a Spark JDBC option; Spark appears to forward unknown
          // options to the driver as connection properties. Confirm it has any effect.
          "showSql" -> "true",
          JDBCOptions.JDBC_BATCH_INSERT_SIZE -> batchSize,
          // NOTE(review): legacy driver class name; Connector/J 8+ uses com.mysql.cj.jdbc.Driver.
          JDBCOptions.JDBC_DRIVER_CLASS -> "com.mysql.jdbc.Driver"
        )).update()
    } finally {
      session.stop()
    }
  }

  /** Adds an upsert-style terminal operation to `DataFrameWriter`. */
  implicit class DataFrameWriterMysqlUpdateEnhance(writer: DataFrameWriter[Row]) {

    /**
     * Saves the writer's DataFrame through `MysqlUpdateRelationProvider`, which writes rows with
     * `INSERT ... ON DUPLICATE KEY UPDATE` instead of plain `INSERT`.
     *
     * Options and partitioning columns already configured on `writer` are kept. The save mode is
     * always forced to `SaveMode.Append`, and any format set on the writer is replaced by the
     * upsert provider.
     *
     * Uses the public `format(...).save()` API, which plans a V1 data-source write, instead of
     * reading the writer's private fields through reflection.
     */
    def update(): Unit =
      writer
        .mode(SaveMode.Append)
        .format(UpsertProviderClassName)
        .save()
  }
}
import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.{createConnectionFactory, getInsertStatement, savePartition}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

import java.sql.Connection

object JdbcUtilsEnhance {

  /**
   * Writes `df` into the JDBC table described by `options` using MySQL upsert statements
   * (`INSERT ... ON DUPLICATE KEY UPDATE`), processing each partition via `JdbcUtils.savePartition`.
   *
   * Follows the structure of Spark's `JdbcUtils.saveTable`; the only difference is that the
   * per-row statement comes from [[getUpdateStatement]].
   *
   * @param df              rows to write
   * @param tableSchema     schema of the existing target table, used to map DataFrame column names
   *                        onto the table's own column names; `None` uses the DataFrame names as-is
   * @param isCaseSensitive whether column names are matched case-sensitively against `tableSchema`
   * @param options         JDBC write options (url, table, batch size, isolation level, ...)
   * @throws IllegalArgumentException if the `numPartitions` option is set to a value below 1
   */
  def updateTable(
      df: DataFrame,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      options: JdbcOptionsInWrite): Unit = {
    val url = options.url
    val table = options.table
    val dialect = JdbcDialects.get(url)
    println(s"dialect = ${dialect}")
    val rddSchema = df.schema
    val getConnection: () => Connection = createConnectionFactory(options)
    val batchSize = options.batchSize
    println(s"batchSize = ${batchSize}")
    val isolationLevel = options.isolationLevel

    val updateStmt = getUpdateStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
    // Printed here rather than inside getUpdateStatement so that the builder stays side-effect free.
    println(updateStmt)

    // `numPartitions` can only lower the partition count (coalesce avoids a shuffle); it never
    // increases it.
    val repartitionedDF = options.numPartitions match {
      case Some(n) if n <= 0 => throw new IllegalArgumentException(
        s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
          "via JDBC. The minimum value is 1.")
      case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
      case _ => df
    }
    repartitionedDF.rdd.foreachPartition { iterator =>
      savePartition(
        getConnection, table, iterator, rddSchema, updateStmt, batchSize, dialect, isolationLevel,
        options)
    }
  }

  /**
   * Builds the MySQL upsert statement used for every row:
   * {{{
   * INSERT INTO <table> (<c1>,<c2>,...) VALUES (?,?,...)
   * ON DUPLICATE KEY UPDATE
   * <c1>=VALUES(<c1>),<c2>=VALUES(<c2>),...
   * }}}
   * followed by a trailing newline.
   *
   * The column order follows `rddSchema`. When `tableSchema` is given, each name is replaced by the
   * matching table column name, so case-sensitive databases get the table's spelling.
   *
   * NOTE(review): the `VALUES(col)` function in ON DUPLICATE KEY UPDATE is deprecated since
   * MySQL 8.0.20 (row aliases are the replacement) but is kept here for older servers.
   *
   * @param table           target table name, inserted into the SQL verbatim
   * @param rddSchema       schema of the rows being written
   * @param tableSchema     schema of the existing target table, if known
   * @param isCaseSensitive whether column-name matching against `tableSchema` is case sensitive
   * @param dialect         dialect used to quote column identifiers
   * @return the parameterised upsert SQL
   * @throws Exception if a DataFrame column has no counterpart in `tableSchema`
   */
  def getUpdateStatement(
      table: String,
      rddSchema: StructType,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      dialect: JdbcDialect): String = {
    val quotedColumns: Array[String] = tableSchema match {
      case None =>
        rddSchema.fields.map(field => dialect.quoteIdentifier(field.name))
      case Some(schema) =>
        val columnNameEquality = if (isCaseSensitive) {
          org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
        } else {
          org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
        }
        // The generated statement follows rddSchema's column order but tableSchema's column
        // names. For case-sensitive DBMSs the existing table spelling must be respected instead
        // of the DataFrame's column names.
        val tableColumnNames = schema.fieldNames
        rddSchema.fields.map { col =>
          val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
            throw new Exception(s"""Column "${col.name}" not found in schema $tableSchema""")
          }
          dialect.quoteIdentifier(normalizedName)
        }
    }
    val columnList = quotedColumns.mkString(",")
    val placeholders = quotedColumns.map(_ => "?").mkString(",")
    // Build the assignments from the individual quoted names. Re-splitting `columnList` on ","
    // would break for quoted identifiers that themselves contain a comma.
    val assignments = quotedColumns.map(col => s"$col=VALUES($col)").mkString(",")
    s"INSERT INTO $table ($columnList) VALUES ($placeholders)\nON DUPLICATE KEY UPDATE\n$assignments\n"
  }
}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.execution.datasources.jdbc.{JdbcOptionsInWrite, JdbcRelationProvider, JdbcUtils}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}


/**
 * JDBC relation provider that writes rows with MySQL upsert semantics
 * (`INSERT ... ON DUPLICATE KEY UPDATE`) instead of plain `INSERT`.
 *
 * Table handling per save mode follows Spark's built-in `JdbcRelationProvider`; only the
 * row-writing step is delegated to `JdbcUtilsEnhance.updateTable`.
 */
class MysqlUpdateRelationProvider extends JdbcRelationProvider {

  /**
   * Writes `df` to the JDBC table named in `parameters`, then returns a relation over that table.
   *
   * Behaviour by table state and save mode:
   *  - table missing: create it from `df.schema`, then upsert.
   *  - `Append`: upsert, matching columns against the existing table's schema.
   *  - `Overwrite`: when the `truncate` option is set and the dialect reports that TRUNCATE does
   *    not cascade, truncate and upsert; otherwise drop, recreate from `df.schema` and upsert.
   *  - `ErrorIfExists`: throw.
   *  - `Ignore`: leave the existing table untouched.
   *
   * @param sqlContext active SQL context; supplies the case-sensitivity setting
   * @param mode       save mode requested by the writer
   * @param parameters JDBC write options (url, dbtable, user, password, ...)
   * @param df         rows to write
   * @return relation over the target table
   * @throws Exception if the table exists and `mode` is `ErrorIfExists`
   */
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    val writeOptions = new JdbcOptionsInWrite(parameters)
    val caseSensitive = sqlContext.sparkSession.sessionState.conf.caseSensitiveAnalysis

    // Every writing branch funnels through here; `schema` drives column-name resolution.
    def upsert(schema: Option[org.apache.spark.sql.types.StructType]): Unit =
      JdbcUtilsEnhance.updateTable(df, schema, caseSensitive, writeOptions)

    val connection = JdbcUtils.createConnectionFactory(writeOptions)()
    try {
      if (!JdbcUtils.tableExists(connection, writeOptions)) {
        createTable(connection, writeOptions.table, df.schema, caseSensitive, writeOptions)
        upsert(Some(df.schema))
      } else {
        mode match {
          case SaveMode.Append =>
            upsert(JdbcUtils.getSchemaOption(connection, writeOptions))

          case SaveMode.Overwrite =>
            val truncateInPlace =
              writeOptions.isTruncate && isCascadingTruncateTable(writeOptions.url) == Some(false)
            if (truncateInPlace) {
              truncateTable(connection, writeOptions)
              upsert(JdbcUtils.getSchemaOption(connection, writeOptions))
            } else {
              // Truncation is unavailable or unsafe: rebuild the table from the DataFrame schema.
              dropTable(connection, writeOptions.table, writeOptions)
              createTable(connection, writeOptions.table, df.schema, caseSensitive, writeOptions)
              upsert(Some(df.schema))
            }

          case SaveMode.ErrorIfExists =>
            throw new Exception(
              s"Table or view '${writeOptions.table}' already exists. SaveMode: ErrorIfExists.")

          case SaveMode.Ignore =>
            // The existing table must stay unchanged; just return the relation below.
            ()
        }
      }
    } finally {
      connection.close()
    }

    createRelation(sqlContext, parameters)
  }
}
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值