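Background
Currently, Spark only supports a few table-level modes when writing to MySQL: Append, Overwrite, ErrorIfExists, and Ignore. Sometimes we need row-level operations such as update, which means constructing a statement like this: insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name=?,age=?;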
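To make the target statement concrete, here is a minimal sketch that builds it from a table name and a column list. The object and method names (UpsertSql, upsertStatement) are hypothetical and not part of Spark; identifier quoting and dialect handling are left out on purpose.

```scala
object UpsertSql {
  // Builds a MySQL "insert ... on duplicate key update" statement.
  // Hypothetical helper for illustration only; table and column names are used as given,
  // without quoting or escaping.
  def upsertStatement(table: String, columns: Seq[String]): String = {
    require(columns.nonEmpty, "at least one column is required")
    val cols = columns.mkString(",")
    val placeholders = columns.map(_ => "?").mkString(",")
    // Every column is assigned again in the update clause, as in the statement above.
    val updates = columns.map(c => s"$c=?").mkString(",")
    s"INSERT INTO $table ($cols) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $updates"
  }
}
```

Note that the statement carries twice as many placeholders as there are columns, so each row's values have to be bound twice when the prepared statement is filled: once for the insert part and once for the update part.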
Requirement: we want to neither break existing code nor introduce a new API. Adding a single option such as savemode=update should be enough.
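Implementation
Meeting this requirement means changing the Spark source. First, create our own save mode, which simply adds Update to the existing values: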
public enum I4SaveMode {
    Append,
    Overwrite,
    ErrorIfExists,
    Ignore,
    Update
}
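The JDBC data source is implemented mainly in JdbcRelationProvider. The method to focus on is createRelation: there we can convert SaveMode to our own mode and pass that mode on to saveTable. The modified method looks like this (every change is marked with a comment):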
override def createRelation(
    sqlContext: SQLContext,
    mode: SaveMode,
    parameters: Map[String, String],
    df: DataFrame): BaseRelation = {
  val options = new JDBCOptions(parameters)
  val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
  // Changed: map Spark's SaveMode to our own save mode
  var saveMode = mode match {
    case SaveMode.Overwrite => I4SaveMode.Overwrite
    case SaveMode.Append => I4SaveMode.Append
    case SaveMode.ErrorIfExists => I4SaveMode.ErrorIfExists
    case SaveMode.Ignore => I4SaveMode.Ignore
  }
  // Key change: if a savemode=update parameter is present, switch to Update mode
  val parameterLower = parameters.map(kv => (kv._1.toLowerCase, kv._2))
  if (parameterLower.get("savemode").contains("update")) {
    saveMode = I4SaveMode.Update
  }
  val conn = JdbcUtils.createConnectionFactory(options)()
  try {
    val tableExists = JdbcUtils.tableExists(conn, options)
    if (tableExists) {
      saveMode match {
        case I4SaveMode.Overwrite =>
          if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
            // In this case, we should truncate table and then load.
            truncateTable(conn, options.table)
            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
            // Changed: pass our own save mode on to saveTable
            saveTable(df, tableSchema, isCaseSensitive, options, saveMode)
          } else {
            ......
          }
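The mode-resolution step above can be tried in isolation. The following is a minimal sketch that runs without Spark: the Mode trait and its case objects are hypothetical stand-ins for I4SaveMode, and resolve is an illustrative name, not a Spark method. As in the listing, the option key is matched case-insensitively while the value must be exactly "update".

```scala
object SaveModeResolver {
  // Hypothetical stand-ins for the I4SaveMode enum, so the sketch needs no Spark classes.
  sealed trait Mode
  case object Append extends Mode
  case object Overwrite extends Mode
  case object ErrorIfExists extends Mode
  case object Ignore extends Mode
  case object Update extends Mode

  // Returns Update when the parameters contain savemode=update (key compared
  // case-insensitively); otherwise keeps the mode derived from Spark's SaveMode.
  def resolve(base: Mode, parameters: Map[String, String]): Mode = {
    val lowered = parameters.map { case (k, v) => k.toLowerCase -> v }
    if (lowered.get("savemode").contains("update")) Update else base
  }
}
```

Because the override is driven purely by an option, existing jobs that do not set savemode keep their current behavior, which is the compatibility goal stated in the requirement.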
Next comes the saveTable method:
def saveTable(
    df: DataFrame,
    tableSchema: Option[StructType],
    isCaseSensitive: Boolean,
    options: JDBCOptions,
    mode: I4SaveMode): Unit = {
  ......
  val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
  .....
  repartitionedDF.foreachPartition(iterator => savePartition(
    getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
  )
}
Here, via getInsertStatement