spark mysql 写_如何让spark sql写mysql的时候支持update操作

本文介绍了如何让Spark SQL在对接MySQL时支持UPDATE操作。通过分析Spark的源码,了解到其默认不支持UPDATE,但可以通过改造`DataFrameWriter`和`JdbcDialect`,利用MySQL的`ON DUPLICATE KEY UPDATE`特性,自定义SQL语句实现更新功能。改造源码的关键在于插入语句中加入判断,根据不同的SaveMode决定执行INSERT或UPDATE操作,并调整setter的数量以适应UPDATE时的占位符数量变化。
摘要由CSDN通过智能技术生成

8684828f735b

如何让sparkSQL在对接mysql的时候,除了支持:Append、Overwrite、ErrorIfExists、Ignore;还要在支持update操作

1、首先了解背景

spark提供了一个枚举类,用来支撑对接数据源的操作模式

8684828f735b

spark支持的save操作

通过源码查看,很明显,spark是不支持update操作的

2、如何让sparkSQL支持update

关键的知识点就是:

我们正常在sparkSQL写数据到mysql的时候:

大概的api是:

dataframe.write

.format("sql.execution.customDatasource.jdbc")

.option("jdbc.driver", "com.mysql.jdbc.Driver")

.option("jdbc.url", "jdbc:mysql://localhost:3306/test?user=root&password=&useUnicode=true&characterEncoding=gbk&autoReconnect=true&failOverReadOnly=false")

.option("jdbc.db", "test")

.save()

那么在底层中,spark会通过JDBC方言JdbcDialect , 将我们要插入的数据翻译成:

insert into student (columns_1 , columns_2 , ...) values (? , ? , ....)

那么通过方言解析出的sql语句就通过PrepareStatement的executeBatch(),将sql语句提交给mysql,然后数据插入;

那么上面的sql语句很明显,完全就是插入代码,并没有我们期望的 update操作,类似:

UPDATE table_name SET field1=new-value1, field2=new-value2

但是mysql独家支持这样的sql语句:

INSERT INTO student (columns_1,columns_2)VALUES ('第一个字段值','第二个字段值') ON DUPLICATE KEY UPDATE columns_1 = '呵呵哒',columns_2 = '哈哈哒';

大概的意思就是,如果数据不存在则插入,如果数据存在,则 执行update操作;

因此,我们的切入点就是,让sparkSQL内部对接JdbcDialect的时候,能够生成这种sql:

INSERT INTO 表名称 (columns_1,columns_2)VALUES ('第一个字段值','第二个字段值') ON DUPLICATE KEY UPDATE columns_1 = '呵呵哒',columns_2 = '哈哈哒';

3、改造源码前,需要了解整体的代码设计和执行流程

首先是:

dataframe.write

调用write方法就是为了返回一个类:

DataFrameWriter

主要是因为DataFrameWriter是sparksql对接外部数据源写入的入口携带类,下面这些内容是给DataFrameWriter注册的携带信息

8684828f735b

携带信息

然后在出发save()操作后,就开始将数据写入;

接下来看save()源码:

8684828f735b

save源码

在上面的源码里面主要是注册DataSource实例,然后使用DataSource的write方法进行数据写入

实例化DataSource的时候:

def save(): Unit = {

assertNotBucketed("save")

val dataSource = DataSource(

df.sparkSession,

className = source,//自定义数据源的包路径

partitionColumns = partitioningColumns.getOrElse(Nil),//分区字段

bucketSpec = getBucketSpec,//分桶(用于hive)

options = extraOptions.toMap)//传入的注册信息

//mode:插入数据方式SaveMode , df:要插入的数据

dataSource.write(mode, df)

}

然后就是dataSource.write(mode, df)的细节,整段的逻辑就是:

根据providingClass.newInstance()去做模式匹配,然后匹配到哪里,就执行哪里的代码;

8684828f735b

dataSource.write(mode, df)

然后看下providingClass是什么:

8684828f735b

providingClass

8684828f735b

loadupDataSource

拿到包路径.DefaultSource之后,程序进入:

8684828f735b

image.png

那么如果是数据库作为写入目标的话,就会走:dataSource.createRelation,直接跟进源码:

8684828f735b

image.png

很明显是个特质,因此哪里实现了特质,程序就会走到哪里了;

实现这个特质的地方就是:包路径.DefaultSource , 然后就在这里面去实现数据的插入和update的支持操作;

4、改造源码

根据代码的流程,最终sparkSQL 将数据写入mysql的操作,会进入:包路径.DefaultSource这个类里面;

也就是说,在这个类里面既要支持spark的正常插入操作(SaveMode),还要在支持update;

如果让sparksql支持update操作,最关键的就是做一个判断,比如:

if(isUpdate){

sql语句:INSERT INTO student (columns_1,columns_2)VALUES ('第一个字段值','第二个字段值') ON DUPLICATE KEY UPDATE columns_1 = '呵呵哒',columns_2 = '哈哈哒';

}else{

insert into student (columns_1 , columns_2 , ...) values (? , ? , ....)

}

但是,在spark生产sql语句的源码中,是这样写的:

8684828f735b

image.png

没有任何的判断逻辑,就是最后生成一个:

INSERT INTO TABLE (字段1 , 字段2....) VALUES (? , ? ...)

所以首要的任务就是 ,怎么能让当前代码支持:

ON DUPLICATE KEY UPDATE

可以做个大胆的设计,就是在insertStatement这个方法中做个如下的判断

def insertStatement(conn: Connection, savemode:CustomSaveMode , table: String, rddSchema: StructType, dialect: JdbcDialect)

: PreparedStatement = {

val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")

val placeholders = rddSchema.fields.map(_ => "?").mkString(",")

if(savemode == CustomSaveMode.update){

//TODO 如果是update,就组装成ON DUPLICATE KEY UPDATE的模式处理

s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"

}esle{

val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"

conn.prepareStatement(sql)

}

}

这样,在用户传递进来的savemode模式,我们进行校验,如果是update操作,就返回对应的sql语句!

所以按照上面的逻辑,我们代码这样写:

8684828f735b

image.png

这样我们就拿到了对应的sql语句;

但是只有这个sql语句还是不行的,因为在spark中会执行jdbc的prepareStatement操作,这里面会涉及到游标。

即jdbc在遍历这个sql的时候,源码会这样做:

8684828f735b

image.png

看下makeSetter:

8684828f735b

image.png

所谓有坑就是:

insert into table (字段1 , 字段2, 字段3) values (? , ? , ?)

那么当前在源码中返回的数组长度应该是3:

val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)

.map(makeSetter(conn, dialect, _)).toArray

但是如果我们此时支持了update操作,既:

insert into table (字段1 , 字段2, 字段3) values (? , ? , ?) ON DUPLICATE KEY UPDATE 字段1 = ?,字段2 = ?,字段3=?;

那么很明显,上面的sql语句提供了6个? , 但在规定字段长度的时候只有3

8684828f735b

image.png

这样的话,后面的update操作就无法执行,程序报错!

所以我们需要有一个 识别机制,既:

if(isupdate){

val numFields = rddSchema.fields.length * 2

}else{

val numFields = rddSchema.fields.length

}

8684828f735b

image.png

row[1,2,3] setter(0,1) //index of setter , index of row setter(1,2) setter(2,3) setter(3,1) setter(4,2) setter(5,3)

所以在prepareStatment中的占位符应该是row的两倍,而且应该是类似这样的一个逻辑

因此,代码改造前样子:

8684828f735b

image.png

8684828f735b

image.png

改造后的样子:

try {

if (supportsTransactions) {

conn.setAutoCommit(false) // Everything in the same db transaction.

conn.setTransactionIsolation(finalIsolationLevel)

}

// val stmt = insertStatement(conn, table, rddSchema, dialect)

//此处采用最新自己的sql语句,封装成prepareStatement

val stmt = conn.prepareStatement(sqlStmt)

println(sqlStmt)

/**

* 在mysql中有这样的操作:

* INSERT INTO user_admin_t (_id,password) VALUES ('1','第一次插入的密码')

* INSERT INTO user_admin_t (_id,password)VALUES ('1','第一次插入的密码') ON DUPLICATE KEY UPDATE _id = 'UpId',password = 'upPassword';

* 如果是下面的ON DUPLICATE KEY操作,那么在prepareStatement中的游标会扩增一倍

* 并且如果没有update操作,那么他的游标是从0开始计数的

* 如果是update操作,要算上之前的insert操作

* */

//makeSetter也要适配update操作,即游标问题

val isUpdate = saveMode == CustomSaveMode.Update

val setters: Array[JDBCValueSetter] = isUpdate match {

case true =>

val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)

.map(makeSetter(conn, dialect, _)).toArray

Array.fill(2)(setters).flatten

case _ =>

rddSchema.fields.map(_.dataType)

.map(makeSetter(conn, dialect, _)).toArray

}

val numFieldsLength = rddSchema.fields.length

val numFields = isUpdate match{

case true => numFieldsLength *2

case _ => numFieldsLength

}

val cursorBegin = numFields / 2

try {

var rowCount = 0

while (iterator.hasNext) {

val row = iterator.next()

var i = 0

while (i < numFields) {

if(isUpdate){

//需要判断当前游标是否走到了ON DUPLICATE KEY UPDATE

i < cursorBegin match{

//说明还没走到update阶段

case true =>

//row.isNullAt 判空,则设置空值

if (row.isNullAt(i)) {

stmt.setNull(i + 1, nullTypes(i))

} else {

setters(i).apply(stmt, row, i, 0)

}

//说明走到了update阶段

case false =>

if (row.isNullAt(i - cursorBegin)) {

//pos - offset

stmt.setNull(i + 1, nullTypes(i - cursorBegin))

} else {

setters(i).apply(stmt, row, i, cursorBegin)

}

}

}else{

if (row.isNullAt(i)) {

stmt.setNull(i + 1, nullTypes(i))

} else {

setters(i).apply(stmt, row, i ,0)

}

}

//滚动游标

i = i + 1

}

stmt.addBatch()

rowCount += 1

if (rowCount % batchSize == 0) {

stmt.executeBatch()

rowCount = 0

}

}

if (rowCount > 0) {

stmt.executeBatch()

}

} finally {

stmt.close()

}

if (supportsTransactions) {

conn.commit()

}

committed = true

Iterator.empty

} catch {

case e: SQLException =>

val cause = e.getNextException

if (cause != null && e.getCause != cause) {

if (e.getCause == null) {

e.initCause(cause)

} else {

e.addSuppressed(cause)

}

}

throw e

} finally {

if (!committed) {

// The stage must fail. We got here through an exception path, so

// let the exception through unless rollback() or close() want to

// tell the user about another problem.

if (supportsTransactions) {

conn.rollback()

}

conn.close()

} else {

// The stage must succeed. We cannot propagate any exception close() might throw.

try {

conn.close()

} catch {

case e: Exception => logWarning("Transaction succeeded, but closing failed", e)

}

}

// A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for

// `PreparedStatement`. The last argument `Int` means the index for the value to be set

// in the SQL statement and also used for the value in `Row`.

//PreparedStatement, Row, position , cursor

private type JDBCValueSetter = (PreparedStatement, Row, Int , Int) => Unit

private def makeSetter(

conn: Connection,

dialect: JdbcDialect,

dataType: DataType): JDBCValueSetter = dataType match {

case IntegerType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

stmt.setInt(pos + 1, row.getInt(pos - cursor))

case LongType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

stmt.setLong(pos + 1, row.getLong(pos - cursor))

case DoubleType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

stmt.setDouble(pos + 1, row.getDouble(pos - cursor))

case FloatType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

stmt.setFloat(pos + 1, row.getFloat(pos - cursor))

case ShortType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

stmt.setInt(pos + 1, row.getShort(pos - cursor))

case ByteType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

stmt.setInt(pos + 1, row.getByte(pos - cursor))

case BooleanType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

stmt.setBoolean(pos + 1, row.getBoolean(pos - cursor))

case StringType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

// println(row.getString(pos))

stmt.setString(pos + 1, row.getString(pos - cursor))

case BinaryType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos - cursor))

case TimestampType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos - cursor))

case DateType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos - cursor))

case t: DecimalType =>

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

stmt.setBigDecimal(pos + 1, row.getDecimal(pos - cursor))

case ArrayType(et, _) =>

// remove type length parameters from end of type name

val typeName = getJdbcType(et, dialect).databaseTypeDefinition

.toLowerCase.split("\\(")(0)

(stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>

val array = conn.createArrayOf(

typeName,

row.getSeq[AnyRef](pos - cursor).toArray)

stmt.setArray(pos + 1, array)

case _ =>

(_: PreparedStatement, _: Row, pos: Int,cursor:Int) =>

throw new IllegalArgumentException(

s"Can't translate non-null value for field $pos")

}

Spark中,可以通过JDBC连接MySQL数据库并使用`upsert`语句来执行插入或更新操作。具体实现如下: 1. 导入必要的依赖: ```xml <dependency> <groupId>mysql</groupId> <artifactId>mysql-connector-java</artifactId> <version>8.0.23</version> </dependency> ``` 2. 创建一个`DataFrame`,并指定`DataFrame`的schema。假设需要将数据插入到名为`users`的MySQL表中,表结构如下: ```sql CREATE TABLE users ( id INT NOT NULL AUTO_INCREMENT, name VARCHAR(100), age INT, PRIMARY KEY (id) ); ``` 则可以定义如下的schema: ```scala import org.apache.spark.sql.types._ val schema = StructType(Seq( StructField("name", StringType), StructField("age", IntegerType) )) ``` 3. 读取数据并将其转换为`DataFrame`: ```scala val rdd = sc.parallelize(Seq( ("Alice", 25), ("Bob", 30), ("Charlie", 35) )) val df = spark.createDataFrame(rdd).toDF("name", "age") ``` 4. 将`DataFrame`入到MySQL表中: ```scala val url = "jdbc:mysql://localhost:3306/mydb" val user = "username" val password = "password" df.write .format("jdbc") .option("url", url) .option("dbtable", "users") .option("user", user) .option("password", password) .option("driver", "com.mysql.jdbc.Driver") .option("rewriteBatchedStatements", "true") .option("batchsize", "10000") .mode("append") .save() ``` 在上述代码中,`url`用于指定MySQL数据库的连接地址,`user`和`password`用于指定数据库的用户名和密码,`dbtable`用于指定要入的表名,`driver`用于指定MySQL的JDBC驱动程序。 `rewriteBatchedStatements`和`batchsize`用于优化入性能。`rewriteBatchedStatements`设置为`true`时,表示使用批量入模式,可以提高入性能。`batchsize`用于指定每批次入的记录数。 5. 如果需要执行`upsert`操作,则可以使用MySQL的`REPLACE INTO`语句或`ON DUPLICATE KEY UPDATE`语句。例如,如果需要根据`name`字段更新记录,则可以使用如下的SQL语句: ```sql INSERT INTO users (name, age) VALUES (?, ?) ON DUPLICATE KEY UPDATE age=VALUES(age) ``` 在Spark中,可以通过以下方式执行`upsert`操作: ```scala df.write .format("jdbc") .option("url", url) .option("dbtable", "users") .option("user", user) .option("password", password) .option("driver", "com.mysql.jdbc.Driver") .option("rewriteBatchedStatements", "true") .option("batchsize", "10000") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .option("rewriteBatchedStatements", "true") .mode("append") .jdbc(url, "users", prop) ``` 在上述代码中,`prop`是一个包含`user`和`password`属性的`java.util.Properties`对象。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值