/**
 * Test case: batch-save rows — update when the key already exists, insert when it does not.
 * Generated SQL shape:
 *   INSERT INTO test_001 VALUES (?,?,?)
 *   ON CONFLICT (id) DO UPDATE SET id=?, name=?, age=?;
 *
 * @author linzhy
 */
object InsertOrUpdateTest {

  /**
   * Entry point: reads table test_001 over JDBC and upserts all of its rows
   * into test_001_copy1, keyed on the "id" column.
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName)
      .master("local[2]")
      .config("spark.debug.maxToStringFields", "100")
      .getOrCreate()

    // JDBC connection settings come from the Typesafe Config on the classpath.
    // `val`, not `var`: the config reference is never reassigned.
    val config = ConfigFactory.load()
    val ods_url = config.getString("pg.oucloud_ods.url")
    val ods_user = config.getString("pg.oucloud_ods.user")
    val ods_password = config.getString("pg.oucloud_ods.password")

    // Load the source table and register it as a temp view for Spark SQL.
    val test_001 = spark.read.format("jdbc")
      .option("url", ods_url)
      .option("dbtable", "test_001")
      .option("user", ods_user)
      .option("password", ods_password)
      .load()
    test_001.createOrReplaceTempView("test_001")

    val sql =
      """
        |SELECT * FROM test_001
        |""".stripMargin
    val dataFrame = spark.sql(sql)

    // Batch save: update rows whose key already exists, insert the rest.
    PgSqlUtil.insertOrUpdateToPgsql(dataFrame, spark.sparkContext, "test_001_copy1", "id")

    spark.stop()
  }
}
Source code of the insertOrUpdateToPgsql method:
/**
 * Batch-inserts or updates rows of `dataFrame` into a PostgreSQL table,
 * using `INSERT ... ON CONFLICT (id) DO UPDATE SET ...` (adapted from the
 * Spark `JdbcUtils.savePartition` write path).
 *
 * Fix vs. the original: a `java.sql.Connection`/`PreparedStatement` is NOT
 * serializable and is bound to the JVM that created it, so it must never be
 * created on the driver and broadcast to executors. Each partition now opens
 * its own connection, prepares its own statement, and commits its own batch.
 *
 * @param dataFrame rows to upsert; its schema drives the column list
 * @param sc        retained for interface compatibility (no longer used —
 *                  the removed broadcast was its only consumer)
 * @param table     target table name
 * @param id        conflict-key column name (primary/unique key)
 */
def insertOrUpdateToPgsql(dataFrame: DataFrame, sc: SparkContext, table: String, id: String): Unit = {
  val tableSchema = dataFrame.schema
  val columns = tableSchema.fields.map(_.name).mkString(",")
  val placeholders = tableSchema.fields.map(_ => "?").mkString(",")
  // Full upsert statement: one '?' per column for the INSERT part,
  // then one '?' per column again for the DO UPDATE SET part.
  val upsertSql =
    s"INSERT INTO $table ($columns) VALUES ($placeholders) on conflict($id) do update set " +
      tableSchema.fields.map(x => x.name.toString + "=?").mkString(",")

  val numFields = tableSchema.fields.length * 2 // insert placeholders + update placeholders
  val updateIndex = numFields / 2               // first index belonging to the UPDATE half
  val batchSize = 2000

  dataFrame.foreachPartition { iterator =>
    // Per-partition, executor-side resources (JDBC objects are not serializable).
    val conn = connectionPool()
    conn.setAutoCommit(false)
    val dialect = JdbcDialects.get(conn.getMetaData.getURL)
    // JDBC null types and per-column setters, derived from the Spark schema.
    val nullTypes = tableSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
    val setters = tableSchema.fields.map(f => makeSetter(conn, f.dataType))
    val ps = conn.prepareStatement(upsertSql)
    try {
      var rowCount = 0
      while (iterator.hasNext) {
        val row = iterator.next()
        var i = 0
        while (i < numFields) {
          if (i < updateIndex) {
            // INSERT half: parameter i maps directly to column i.
            if (row.isNullAt(i)) ps.setNull(i + 1, nullTypes(i))
            else setters(i).apply(ps, row, i, 0)
          } else {
            // UPDATE half: parameter i maps back to column (i - updateIndex).
            if (row.isNullAt(i - updateIndex)) ps.setNull(i + 1, nullTypes(i - updateIndex))
            else setters(i - updateIndex).apply(ps, row, i, updateIndex)
          }
          i += 1
        }
        ps.addBatch()
        rowCount += 1
        if (rowCount % batchSize == 0) { // flush every batchSize rows
          ps.executeBatch()
          rowCount = 0
        }
      }
      if (rowCount > 0) ps.executeBatch() // flush the final partial batch
      conn.commit()
    } catch {
      // NonFatal: let OutOfMemoryError / InterruptedException etc. propagate.
      case scala.util.control.NonFatal(e) =>
        logError("Error in execution of insert. " + e.getMessage)
        conn.rollback()
      // insertError(connectionPool("OuCloud_ODS"),"insertOrUpdateToPgsql",e.getMessage)
    } finally {
      ps.close()
      conn.close()
    }
  }
}