Spark SQL 批量插入或更新（Upsert）数据到 PostgreSQL

  • 在 Spark SQL 中，DataFrame 保存数据到数据库时只有 Append、Overwrite、ErrorIfExists、Ignore 四种保存模式（SaveMode），无法满足项目需求。现参照 Spark save 的源码进行改造，实现批量保存数据：记录已存在则更新，不存在则插入。
/**
 * Manual test for `PgSqlUtil.insertOrUpdateToPgsql`: reads every row of table `test_001`
 * over JDBC and upserts them into `test_001_copy1` keyed on `id`. Rows whose `id` does not
 * exist yet are inserted; rows whose `id` already exists are updated, in the spirit of:
 *
 *   INSERT INTO test_001_copy1 (...) VALUES (...)
 *   ON CONFLICT (id) DO UPDATE SET ...;
 *
 * Connection settings come from the `pg.oucloud_ods.{url,user,password}` keys loaded by
 * `ConfigFactory.load()`. Spark runs locally with 2 threads (`local[2]`).
 *
 * @author linzhy
 */
object InsertOrUpdateTest {
  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName)
      .master("local[2]")
      .config("spark.debug.maxToStringFields", "100")
      .getOrCreate()

    // Always release the session, even if reading or writing fails.
    try {
      val config = ConfigFactory.load()
      val ods_url = config.getString("pg.oucloud_ods.url")
      val ods_user = config.getString("pg.oucloud_ods.user")
      val ods_password = config.getString("pg.oucloud_ods.password")

      val test_001 = spark.read.format("jdbc")
        .option("url", ods_url)
        .option("dbtable", "test_001")
        .option("user", ods_user)
        .option("password", ods_password)
        .load()

      test_001.createOrReplaceTempView("test_001")

      val sql =
        """
          |SELECT * FROM test_001
          |""".stripMargin

      val dataFrame = spark.sql(sql)
      // Batch upsert: insert rows that do not exist yet, update the ones that do.
      PgSqlUtil.insertOrUpdateToPgsql(dataFrame, spark.sparkContext, "test_001_copy1", "id")
    } finally {
      spark.stop()
    }
  }
}
  • insertOrUpdateToPgsql 方法源码
/**
   * Batch upsert ("insert, or update on conflict") of a DataFrame into a PostgreSQL table,
   * modelled on the JDBC writer behind Spark's `DataFrame.write.save()`.
   *
   * Executes `INSERT INTO table (cols) VALUES (?, ...) ON CONFLICT (id) DO UPDATE SET
   * col = EXCLUDED.col, ...` in JDBC batches of 2000 rows.
   *
   * JDBC connections and prepared statements are not serializable and cannot be shared
   * between tasks. Therefore only the SQL text, the JDBC null types and the value setters are
   * shipped to the executors. Each non-empty partition then opens its own connection and runs
   * in its own transaction.
   *
   * A failing partition is rolled back and its task fails; Spark may retry it, which is safe
   * because the upsert is idempotent. Partitions that already committed stay committed, so
   * the write as a whole is NOT atomic.
   *
   * Failure of the write job is logged and not rethrown, so callers cannot detect it from the
   * return value. Failures while resolving column metadata on the driver do propagate.
   *
   * @param dataFrame rows to write; column names must match the target table's columns
   * @param sc        unused; kept for source compatibility with existing callers
   * @param table     target table name
   * @param id        conflict target column(s) for `ON CONFLICT (...)`, e.g. `"id"` or `"a,b"`
   */
  def insertOrUpdateToPgsql(dataFrame: DataFrame, sc: SparkContext, table: String, id: String): Unit = {
    val fields = dataFrame.schema.fields
    val numFields = fields.length
    val upsertSql = buildUpsertSql(table, fields.map(_.name).toSeq, id)
    val batchSize = 2000

    // Resolve JDBC null types and value setters on the driver with a short-lived connection.
    // Only these values and the SQL text are captured by the task closure below.
    val (nullTypes, setters) = {
      val metaConn = connectionPool()
      try {
        val dialect = JdbcDialects.get(metaConn.getMetaData.getURL)
        (fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType),
          fields.map(f => makeSetter(metaConn, f.dataType)))
      } finally {
        metaConn.close()
      }
    }

    try {
      // RDD.foreachPartition has a single overload, so the untyped lambda needs no annotation.
      dataFrame.rdd.foreachPartition { rows =>
        // Empty partitions never open a connection.
        if (rows.hasNext) {
          // NOTE(review): assumes connectionPool() works on executors (its configuration must
          // be available there) — confirm for cluster deployments.
          val conn = connectionPool()
          try {
            conn.setAutoCommit(false)
            val ps = conn.prepareStatement(upsertSql)
            try {
              var pending = 0
              rows.foreach { row =>
                var i = 0
                while (i < numFields) {
                  if (row.isNullAt(i)) ps.setNull(i + 1, nullTypes(i))
                  else setters(i)(ps, row, i, 0)
                  i += 1
                }
                ps.addBatch()
                pending += 1
                if (pending == batchSize) {
                  ps.executeBatch()
                  pending = 0
                }
              }
              if (pending > 0) ps.executeBatch()
            } finally {
              ps.close()
            }
            conn.commit()
          } catch {
            case e: Exception =>
              // Keep the original failure; attach any rollback failure to it.
              try conn.rollback() catch { case r: Exception => e.addSuppressed(r) }
              throw e
          } finally {
            conn.close()
          }
        }
      }
    } catch {
      case e: Exception =>
        logError("Error in execution of insert. " + e.getMessage)
    }
  }

  /**
   * Builds the PostgreSQL upsert statement for `columns`, with one `?` per column.
   *
   * Non-key columns are updated from the `EXCLUDED` row, so each value is bound only once.
   * Conflict-key columns are left out of the SET list, since overwriting a key with itself is
   * pointless. If no columns remain, the statement uses `DO NOTHING`, because an empty SET
   * list is a syntax error.
   *
   * @param table       target table name
   * @param columns     column names, in the order their values will be bound
   * @param conflictKey comma-separated conflict target column(s)
   */
  private def buildUpsertSql(table: String, columns: Seq[String], conflictKey: String): String = {
    val keyColumns = conflictKey.split(",").map(_.trim)
    val assignments = columns
      .filterNot(c => keyColumns.exists(_.equalsIgnoreCase(c)))
      .map(c => s"$c = EXCLUDED.$c")
    val action =
      if (assignments.isEmpty) "DO NOTHING"
      else assignments.mkString("DO UPDATE SET ", ", ", "")
    val placeholders = columns.map(_ => "?").mkString(",")
    s"INSERT INTO $table (${columns.mkString(",")}) VALUES ($placeholders) " +
      s"ON CONFLICT ($conflictKey) $action"
  }
  • 从 Spark 源码中摘取的 getJdbcType / makeSetter 函数
 /**
   * Maps a Spark SQL type to its JDBC type for `dialect`.
   *
   * Tries the dialect-specific mapping first, then falls back to `getCommonJDBCType`.
   *
   * @throws IllegalArgumentException if neither mapping knows `dt`
   */
 private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
    val candidate = dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt))
    candidate match {
      case Some(jdbcType) => jdbcType
      case None =>
        throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}")
    }
  }

 /** Setter signature: (statement, row, pos, currentpos) => Unit. */
 private type JDBCValueSetter_add = (PreparedStatement, Row, Int,Int) => Unit

  /**
   * Builds a setter that binds one non-null value of `dataType` into a PreparedStatement.
   *
   * The returned function is called as `setter(stmt, row, pos, currentpos)`. It binds JDBC
   * parameter `pos + 1` from row column `pos - currentpos`, so a row column can be bound at
   * a parameter slot shifted by `currentpos`.
   *
   * Nulls must be bound by the caller (e.g. via `setNull`) before reaching these setters.
   * For an unsupported type, the returned setter throws IllegalArgumentException when
   * invoked, not when it is built.
   *
   * @param conn     not used by any of the supported types
   * @param dataType Spark SQL type of the column
   */
  private def makeSetter(conn: Connection, dataType: DataType): JDBCValueSetter_add = {
    // Adapts a (statement, parameterIndex, row, columnIndex) binder to the setter signature.
    def binder(bind: (PreparedStatement, Int, Row, Int) => Unit): JDBCValueSetter_add =
      (stmt, row, pos, currentpos) => bind(stmt, pos + 1, row, pos - currentpos)

    dataType match {
      case IntegerType   => binder((s, p, r, c) => s.setInt(p, r.getInt(c)))
      case LongType      => binder((s, p, r, c) => s.setLong(p, r.getLong(c)))
      case DoubleType    => binder((s, p, r, c) => s.setDouble(p, r.getDouble(c)))
      case FloatType     => binder((s, p, r, c) => s.setFloat(p, r.getFloat(c)))
      case ShortType     => binder((s, p, r, c) => s.setInt(p, r.getShort(c)))
      case ByteType      => binder((s, p, r, c) => s.setInt(p, r.getByte(c)))
      case BooleanType   => binder((s, p, r, c) => s.setBoolean(p, r.getBoolean(c)))
      case StringType    => binder((s, p, r, c) => s.setString(p, r.getString(c)))
      case BinaryType    => binder((s, p, r, c) => s.setBytes(p, r.getAs[Array[Byte]](c)))
      case TimestampType =>
        binder((s, p, r, c) => s.setTimestamp(p, r.getAs[java.sql.Timestamp](c)))
      case DateType      => binder((s, p, r, c) => s.setDate(p, r.getAs[java.sql.Date](c)))
      case _: DecimalType => binder((s, p, r, c) => s.setBigDecimal(p, r.getDecimal(c)))
      case _ =>
        (stmt, row, pos, currentpos) =>
          throw new IllegalArgumentException(
            s"Can't translate non-null value for field $pos")
    }
  }
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值