本例子是对网上资源的完善:修复了其中的几处 bug,并加入了新的函数。
package edu.xd.spark
import java.sql.{Date, Timestamp}
import java.util.Properties
import org.apache.log4j.Logger
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
object OperatorMySql {
val logger: Logger = Logger.getLogger(this.getClass.getSimpleName)
/**
 * Inserts every row of a DataFrame into a MySQL table, one transaction per partition,
 * using connections obtained from `MySqlPoolManager`.
 *
 * Values are bound positionally (`insert into t values(?,...)`), so the DataFrame column
 * order must match the table column order. Each value is bound with the JDBC setter that
 * matches its Spark SQL type; null values are bound with `setNull` using a JDBC type code
 * derived from the DataFrame column type.
 *
 * A failure inside a partition is printed to stdout and that partition's transaction is
 * rolled back; the exception is not propagated to the caller.
 *
 * @param tableName       target table name
 * @param resultDateFrame DataFrame whose rows are written
 */
def saveDFtoDBUsePool(tableName: String, resultDateFrame: DataFrame): Unit = {
  val columnCount = resultDateFrame.columns.length
  val sql = getInsertSql(tableName, columnCount)
  val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
  // JDBC type codes used for setNull, computed once on the driver from the DataFrame schema.
  val sqlTypes: Array[Int] = columnDataTypes.map {
    case _: ByteType => java.sql.Types.TINYINT
    case _: ShortType => java.sql.Types.SMALLINT
    case _: IntegerType => java.sql.Types.INTEGER
    case _: LongType => java.sql.Types.BIGINT
    case _: BooleanType => java.sql.Types.BIT
    case _: FloatType => java.sql.Types.FLOAT
    case _: DoubleType => java.sql.Types.DOUBLE
    case _: StringType => java.sql.Types.VARCHAR
    case _: TimestampType => java.sql.Types.TIMESTAMP
    case _: DateType => java.sql.Types.DATE
    case _ => java.sql.Types.NULL
  }
  resultDateFrame.foreachPartition(
    (partitionRecords: Iterator[Row]) => {
      // Binds field `fieldIndex` of `row` to the 1-based statement parameter `paramIndex`.
      def bindParam(stmt: java.sql.PreparedStatement, paramIndex: Int, row: Row, fieldIndex: Int): Unit = {
        if (row.isNullAt(fieldIndex)) {
          stmt.setNull(paramIndex, sqlTypes(fieldIndex))
        } else {
          columnDataTypes(fieldIndex) match {
            // Byte/Short values are boxed as java.lang.Byte/Short; reading them as Int would
            // throw ClassCastException, so use the matching typed getters.
            case _: ByteType => stmt.setByte(paramIndex, row.getByte(fieldIndex))
            case _: ShortType => stmt.setShort(paramIndex, row.getShort(fieldIndex))
            case _: IntegerType => stmt.setInt(paramIndex, row.getInt(fieldIndex))
            case _: LongType => stmt.setLong(paramIndex, row.getLong(fieldIndex))
            case _: BooleanType => stmt.setBoolean(paramIndex, row.getBoolean(fieldIndex))
            case _: FloatType => stmt.setFloat(paramIndex, row.getFloat(fieldIndex))
            case _: DoubleType => stmt.setDouble(paramIndex, row.getDouble(fieldIndex))
            case _: StringType => stmt.setString(paramIndex, row.getString(fieldIndex))
            case _: TimestampType => stmt.setTimestamp(paramIndex, row.getAs[Timestamp](fieldIndex))
            case _: DateType => stmt.setDate(paramIndex, row.getAs[Date](fieldIndex))
            case other => throw new RuntimeException(s"nonsupport ${other} !!!")
          }
        }
      }

      val conn = MySqlPoolManager.getMysqlManager.getConnection
      try {
        conn.setAutoCommit(false)
        val statement = conn.prepareStatement(sql)
        try {
          partitionRecords.foreach { record =>
            for (i <- 0 until columnCount) {
              bindParam(statement, i + 1, record, i)
            }
            statement.addBatch()
          }
          statement.executeBatch()
          conn.commit()
        } finally {
          statement.close()
        }
      } catch {
        case e: Exception =>
          println(s"@@ saveDFtoDBUsePool ${e.getMessage}")
          // Discard whatever part of the batch was applied before the failure.
          if (conn != null) {
            try conn.rollback()
            catch {
              case rollbackError: Exception =>
                println(s"@@ saveDFtoDBUsePool rollback failed: ${rollbackError.getMessage}")
            }
          }
      } finally {
        if (conn != null) conn.close()
      }
    })
}
/**
 * Builds a positional INSERT statement of the form `insert into <table> values(?,?,...)`.
 *
 * @param tableName  target table name, inserted into the SQL text as-is
 * @param colNumbers number of `?` placeholders to emit; zero or less yields `values()`
 * @return the SQL text
 */
def getInsertSql(tableName: String, colNumbers: Int): String = {
  val placeholders = Seq.fill(colNumbers)("?").mkString(",")
  s"insert into ${tableName} values(${placeholders})"
}
/**
 * Reads the MySQL connection settings as a tuple.
 *
 * The values come from `db.properties` on the classpath, using the same keys the
 * connection pool (`MySqlPool`) reads: `mysql.jdbc.url`, `mysql.jdbc.user` and
 * `mysql.jdbc.password`.
 *
 * @return `(jdbcURL, userName, password)`
 */
def getMysqlInfo: (String, String, String) = {
  val propertiesFile = "db.properties"
  val jdbcURL = PropertiyUtils.getFileProperties(propertiesFile, "mysql.jdbc.url")
  val userName = PropertiyUtils.getFileProperties(propertiesFile, "mysql.jdbc.user")
  val password = PropertiyUtils.getFileProperties(propertiesFile, "mysql.jdbc.password")
  (jdbcURL, userName, password)
}
/**
 * Loads a MySQL table as a DataFrame through Spark's JDBC reader.
 *
 * @param sqlContext     SQLContext used to read the table
 * @param mysqlTableName table name
 * @param queryCondition optional filter expression applied with `where`; null, empty or
 *                       whitespace-only means "no filter"
 * @return the table contents, filtered when a condition is given
 */
def getDFFromeMysql(sqlContext: SQLContext, mysqlTableName: String, queryCondition: String = ""): DataFrame = {
  val (jdbcURL, userName, password) = getMysqlInfo
  val prop = new Properties()
  prop.put("user", userName)
  prop.put("password", password)
  val table = sqlContext.read.jdbc(jdbcURL, mysqlTableName, prop)
  // A blank condition would fail to parse as an expression, so treat it like "no condition".
  Option(queryCondition).map(_.trim).filter(_.nonEmpty) match {
    case Some(condition) => table.where(condition)
    case None => table
  }
}
/**
 * Drops a MySQL table.
 *
 * `Statement.execute` returns false for statements that produce no result set (such as
 * DROP TABLE), so its return value cannot be used as a success flag; success is reported
 * as "the statement ran without throwing".
 *
 * @param mysqlTableName name of the table to drop; inserted into the SQL text as-is, so it
 *                       must come from a trusted source
 * @return true when the statement executed, false when any exception occurred
 */
def dropMysqlTable( mysqlTableName: String): Boolean = {
  val conn = MySqlPoolManager.getMysqlManager.getConnection
  try {
    val statement = conn.createStatement()
    try {
      statement.execute(s"drop table $mysqlTableName")
      true
    } finally {
      statement.close()
    }
  } catch {
    case e: Exception =>
      println(s"mysql drop MysqlTable error:${e.getMessage}")
      false
  } finally {
    if (conn != null) conn.close()
  }
}
/**
 * Deletes rows from a MySQL table.
 *
 * `Statement.execute` returns false for statements that produce an update count (such as
 * DELETE), so its return value cannot be used as a success flag; success is reported as
 * "the statement ran without throwing".
 *
 * @param SQLContext     unused; kept so existing callers keep compiling
 * @param mysqlTableName table name, inserted into the SQL text as-is
 * @param condition      filter text placed directly after `where`; inserted as-is, so it
 *                       must come from a trusted source
 * @return true when the statement executed, false when any exception occurred
 */
def deleteMysqlTableData(SQLContext: SQLContext, mysqlTableName: String, condition: String): Boolean = {
  val conn = MySqlPoolManager.getMysqlManager.getConnection
  try {
    val statement = conn.createStatement()
    try {
      statement.execute(s"delete from $mysqlTableName where $condition")
      true
    } finally {
      statement.close()
    }
  } catch {
    case e: Exception =>
      println(s"mysql delete MysqlTableNameData error:${e.getMessage}")
      false
  } finally {
    if (conn != null) conn.close()
  }
}
/**
 * Saves a DataFrame into MySQL, creating the table first when it does not exist.
 *
 * @param tableName       target table name
 * @param resultDataFrame DataFrame to store
 */
def saveDFtoDBCreateTableIfNotExists(tableName: String, resultDataFrame: DataFrame): Unit = {
  // Step 1: create the table from the DataFrame schema when it is missing.
  createTableIfNotExist(tableName, resultDataFrame)

  // Step 2: check that the table's columns match the DataFrame's in count, name and order.
  verifyFieldConsistency(tableName, resultDataFrame)

  // Step 3: write the rows.
  saveDFtoDBUsePool(tableName, resultDataFrame)
}
/**
 * Creates the table when it does not exist, then upserts every row of the DataFrame,
 * treating all DataFrame columns as columns to update on a duplicate key.
 *
 * The Boolean status returned by the upsert is discarded because this method returns Unit.
 *
 * @param tableName target table name
 * @param df        DataFrame to insert or update
 */
def saveOrUpdateAndCreateTableIfNotExits(tableName:String,df:DataFrame): Unit = {
  createTableIfNotExist(tableName, df)
  // No connection is opened here: the helpers obtain and release their own connections.
  insertOrUpdateDFtoDBUserPoolWithMysqlSaveUpdateNew(tableName, df, df.columns)
}
/**
 * Creates a MySQL table from a DataFrame schema when no table with that name is found.
 *
 * A column named `id` (case-insensitive) becomes an auto-increment primary key; every other
 * column is mapped from its Spark SQL type to a MySQL column type. The generated DDL is
 * printed to stdout before it is executed.
 *
 * NOTE(review): the existence check passes a null catalog and "%" schema pattern to
 * `getColumns`, so it presumably matches a table of this name in any visible database —
 * confirm that is intended.
 *
 * @param tableName table name
 * @param df        DataFrame whose schema defines the columns
 * @return the Boolean result of `Statement.execute` when the table was created, Unit otherwise
 * @throws RuntimeException when a column has a Spark SQL type with no mapping here
 */
def createTableIfNotExist(tableName: String, df: DataFrame): AnyVal = {
  val conn = MySqlPoolManager.getMysqlManager.getConnection
  try {
    val colResultSet = conn.getMetaData.getColumns(null, "%", tableName, "%")
    // At least one column row means the table already exists.
    val tableExists = try colResultSet.next() finally colResultSet.close()
    if (!tableExists) {
      val columnDefinitions = df.schema.fields.map { x =>
        if (x.name.equalsIgnoreCase("id")) {
          // A column named id becomes the non-null auto-increment primary key.
          s"`${x.name}` int(255) not null auto_increment primary key"
        } else {
          x.dataType match {
            case _: ByteType => s"`${x.name}` int(10) default null"
            case _: ShortType => s"`${x.name}` int(10) default null"
            case _: IntegerType => s"`${x.name}` int(20) default null"
            case _: LongType => s"`${x.name}` bigint(20) default null"
            case _: BooleanType => s"`${x.name}` tinyint default null"
            case _: FloatType => s"`${x.name}` float default null"
            case _: DoubleType => s"`${x.name}` double default null"
            case _: StringType => s"`${x.name}` varchar(64) default null"
            case _: TimestampType => s"`${x.name}` timestamp default current_timestamp"
            case _: DateType => s"`${x.name}` date default null"
            case _ => throw new RuntimeException(s"non support ${x.dataType}!!!")
          }
        }
      }
      // mkString avoids the trailing-comma removal, which failed on an empty schema.
      val sql_createTable = s"CREATE TABLE if not exists `$tableName` (" +
        columnDefinitions.mkString(",") +
        ") engine = InnoDB default charset=utf8mb4"
      println(sql_createTable)
      val statement = conn.createStatement()
      try statement.execute(sql_createTable) finally statement.close()
    }
  } finally {
    if (conn != null) conn.close()
  }
}
/**
 * Builds a MySQL upsert statement that reuses the inserted values on a duplicate key:
 * `insert into t(a,b) values(?,?) on duplicate key update b=values(b)`.
 *
 * When `updateColumns` is empty the `on duplicate key update` clause is omitted, so the
 * result is a plain INSERT instead of a statement ending in a dangling clause.
 *
 * @param tableName     target table name, inserted into the SQL text as-is
 * @param cols          column names to insert, in parameter order
 * @param updateColumns column names to overwrite when the key already exists
 * @return the SQL text
 */
def getInsertOrUpdateSqlNew(tableName: String, cols: Array[String], updateColumns: Array[String]): String = {
  val columnList = cols.mkString(",")
  val placeholders = cols.map(_ => "?").mkString(",")
  val insertSql = s"insert into ${tableName}(${columnList}) values(${placeholders})"
  if (updateColumns.isEmpty) {
    insertSql
  } else {
    val assignments = updateColumns.map(c => s"$c=values($c)").mkString(",")
    s"${insertSql} on duplicate key update ${assignments}"
  }
}
/**
 * Builds a MySQL upsert statement with separate parameters for the update part:
 * `insert into t(a,b) values(?,?) on duplicate key update b=?`.
 *
 * The statement therefore takes `cols.length + updateColumns.length` parameters: the
 * insert values first, then one value per update column in `updateColumns` order.
 *
 * When `updateColumns` is empty the `on duplicate key update` clause is omitted, so the
 * result is a plain INSERT instead of a statement ending in a dangling clause.
 *
 * @param tableName     target table name, inserted into the SQL text as-is
 * @param cols          column names to insert, in parameter order
 * @param updateColumns column names to overwrite when the key already exists
 * @return the SQL text
 */
def getInsertOrUpdateSql(tableName: String, cols: Array[String], updateColumns: Array[String]): String = {
  val columnList = cols.mkString(",")
  val placeholders = cols.map(_ => "?").mkString(",")
  val insertSql = s"insert into ${tableName}(${columnList}) values(${placeholders})"
  if (updateColumns.isEmpty) {
    insertSql
  } else {
    val assignments = updateColumns.map(c => s"$c=?").mkString(",")
    s"${insertSql} on duplicate key update ${assignments}"
  }
}
/**
 * Upserts every row of a DataFrame into MySQL using
 * `insert ... on duplicate key update col=values(col)`, one transaction per partition.
 *
 * Rows are sent in batches of 1000 and committed once at the end of the partition. When a
 * partition fails, the error is printed to stdout, that partition's transaction is rolled
 * back, and the partition reports failure.
 *
 * The per-partition outcome is produced on the executors and collected on the driver;
 * variables mutated inside the partition closure would not be visible to the driver.
 *
 * @param tableName       target table name
 * @param resultDateFrame DataFrame to write
 * @param updateColumns   columns to overwrite when the key already exists
 * @return true when every partition was written and committed, false otherwise
 */
def insertOrUpdateDFtoDBUserPoolWithMysqlSaveUpdateNew(tableName: String, resultDateFrame: DataFrame, updateColumns: Array[String]): Boolean = {
  val batchSize = 1000
  val columnCount = resultDateFrame.columns.length
  val sql = getInsertOrUpdateSqlNew(tableName, resultDateFrame.columns, updateColumns)
  val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
  // JDBC type codes used for setNull, derived from the DataFrame schema. The statement names
  // its columns, so the table's own column order cannot be used to look up the type.
  val sqlTypes: Array[Int] = columnDataTypes.map {
    case _: ByteType => java.sql.Types.TINYINT
    case _: ShortType => java.sql.Types.SMALLINT
    case _: IntegerType => java.sql.Types.INTEGER
    case _: LongType => java.sql.Types.BIGINT
    case _: BooleanType => java.sql.Types.BIT
    case _: FloatType => java.sql.Types.FLOAT
    case _: DoubleType => java.sql.Types.DOUBLE
    case _: StringType => java.sql.Types.VARCHAR
    case _: TimestampType => java.sql.Types.TIMESTAMP
    case _: DateType => java.sql.Types.DATE
    case _ => java.sql.Types.NULL
  }
  println(s"\n$sql")
  val partitionStatuses = resultDateFrame.rdd.mapPartitions { partitionRecords =>
    // Binds field `fieldIndex` of `row` to the 1-based statement parameter `paramIndex`.
    def bindParam(stmt: java.sql.PreparedStatement, paramIndex: Int, row: Row, fieldIndex: Int): Unit = {
      if (row.isNullAt(fieldIndex)) {
        stmt.setNull(paramIndex, sqlTypes(fieldIndex))
      } else {
        columnDataTypes(fieldIndex) match {
          // Byte/Short values are boxed as java.lang.Byte/Short; reading them as Int would
          // throw ClassCastException, so use the matching typed getters.
          case _: ByteType => stmt.setByte(paramIndex, row.getByte(fieldIndex))
          case _: ShortType => stmt.setShort(paramIndex, row.getShort(fieldIndex))
          case _: IntegerType => stmt.setInt(paramIndex, row.getInt(fieldIndex))
          case _: LongType => stmt.setLong(paramIndex, row.getLong(fieldIndex))
          case _: BooleanType => stmt.setBoolean(paramIndex, row.getBoolean(fieldIndex))
          case _: FloatType => stmt.setFloat(paramIndex, row.getFloat(fieldIndex))
          case _: DoubleType => stmt.setDouble(paramIndex, row.getDouble(fieldIndex))
          case _: StringType => stmt.setString(paramIndex, row.getString(fieldIndex))
          case _: TimestampType => stmt.setTimestamp(paramIndex, row.getAs[Timestamp](fieldIndex))
          case _: DateType => stmt.setDate(paramIndex, row.getAs[Date](fieldIndex))
          case other => throw new RuntimeException(s"nonsupport ${other} !!!")
        }
      }
    }

    val conn = MySqlPoolManager.getMysqlManager.getConnection
    val succeeded =
      try {
        conn.setAutoCommit(false)
        val statement = conn.prepareStatement(sql)
        try {
          var pending = 0
          partitionRecords.foreach { record =>
            for (i <- 0 until columnCount) {
              bindParam(statement, i + 1, record, i)
            }
            statement.addBatch()
            pending += 1
            // Flush inside the row loop so a batch never grows past batchSize.
            if (pending >= batchSize) {
              statement.executeBatch()
              pending = 0
            }
          }
          if (pending > 0) statement.executeBatch()
          conn.commit()
          true
        } finally {
          statement.close()
        }
      } catch {
        case e: Exception =>
          println(s"@@ ${e.getMessage}")
          // Do not commit a partially applied partition.
          if (conn != null) {
            try conn.rollback()
            catch {
              case rollbackError: Exception =>
                println(s"@@ rollback failed: ${rollbackError.getMessage}")
            }
          }
          false
      } finally {
        if (conn != null) conn.close()
      }
    Iterator.single(succeeded)
  }.collect()
  partitionStatuses.forall(identity)
}
/**
 * Upserts every row of a DataFrame into MySQL using
 * `insert ... on duplicate key update col=?`, one transaction per partition.
 *
 * Each row binds its column values to the insert parameters first, then binds the values
 * of `updateColumns` (looked up by name in the DataFrame schema) to the update parameters.
 * Rows are sent in batches of 1000 and committed once at the end of the partition. When a
 * partition fails, the error is printed to stdout, that partition's transaction is rolled
 * back, and the partition reports failure.
 *
 * The per-partition outcome is produced on the executors and collected on the driver;
 * variables mutated inside the partition closure would not be visible to the driver.
 *
 * @param tableName       target table name
 * @param resultDateFrame DataFrame to write
 * @param updateColumns   columns to overwrite when the key already exists; each must be a
 *                        column of the DataFrame
 * @return true when every partition was written and committed, false otherwise
 */
def insertOrUpdateDFtoDBUserPool(tableName: String, resultDateFrame: DataFrame, updateColumns: Array[String]): Boolean = {
  val batchSize = 1000
  val schema = resultDateFrame.schema
  val columnCount = resultDateFrame.columns.length
  val sql = getInsertOrUpdateSql(tableName, resultDateFrame.columns, updateColumns)
  val columnDataTypes = schema.fields.map(_.dataType)
  // JDBC type codes used for setNull, derived from the DataFrame schema. The statement names
  // its columns, so the table's own column order cannot be used to look up the type.
  val sqlTypes: Array[Int] = columnDataTypes.map {
    case _: ByteType => java.sql.Types.TINYINT
    case _: ShortType => java.sql.Types.SMALLINT
    case _: IntegerType => java.sql.Types.INTEGER
    case _: LongType => java.sql.Types.BIGINT
    case _: BooleanType => java.sql.Types.BIT
    case _: FloatType => java.sql.Types.FLOAT
    case _: DoubleType => java.sql.Types.DOUBLE
    case _: StringType => java.sql.Types.VARCHAR
    case _: TimestampType => java.sql.Types.TIMESTAMP
    case _: DateType => java.sql.Types.DATE
    case _ => java.sql.Types.NULL
  }
  println(s"\n$sql")
  val partitionStatuses = resultDateFrame.rdd.mapPartitions { partitionRecords =>
    // Binds field `fieldIndex` of `row` to the 1-based statement parameter `paramIndex`.
    def bindParam(stmt: java.sql.PreparedStatement, paramIndex: Int, row: Row, fieldIndex: Int): Unit = {
      if (row.isNullAt(fieldIndex)) {
        stmt.setNull(paramIndex, sqlTypes(fieldIndex))
      } else {
        columnDataTypes(fieldIndex) match {
          // Byte/Short values are boxed as java.lang.Byte/Short; reading them as Int would
          // throw ClassCastException, so use the matching typed getters.
          case _: ByteType => stmt.setByte(paramIndex, row.getByte(fieldIndex))
          case _: ShortType => stmt.setShort(paramIndex, row.getShort(fieldIndex))
          case _: IntegerType => stmt.setInt(paramIndex, row.getInt(fieldIndex))
          case _: LongType => stmt.setLong(paramIndex, row.getLong(fieldIndex))
          case _: BooleanType => stmt.setBoolean(paramIndex, row.getBoolean(fieldIndex))
          case _: FloatType => stmt.setFloat(paramIndex, row.getFloat(fieldIndex))
          case _: DoubleType => stmt.setDouble(paramIndex, row.getDouble(fieldIndex))
          case _: StringType => stmt.setString(paramIndex, row.getString(fieldIndex))
          case _: TimestampType => stmt.setTimestamp(paramIndex, row.getAs[Timestamp](fieldIndex))
          case _: DateType => stmt.setDate(paramIndex, row.getAs[Date](fieldIndex))
          case other => throw new RuntimeException(s"no support ${other} !!!")
        }
      }
    }

    val conn = MySqlPoolManager.getMysqlManager.getConnection
    val succeeded =
      try {
        conn.setAutoCommit(false)
        // Resolve the update columns once per partition; an unknown column name fails here
        // and is reported through the catch block below.
        val updateFieldIndexes = updateColumns.map(name => schema.fieldIndex(name))
        val statement = conn.prepareStatement(sql)
        try {
          var pending = 0
          partitionRecords.foreach { record =>
            // Parameters 1..columnCount: the values to insert.
            for (i <- 0 until columnCount) {
              bindParam(statement, i + 1, record, i)
            }
            // Parameters after columnCount: the values for "on duplicate key update",
            // read from the update column's own field rather than from the loop position.
            for (k <- updateFieldIndexes.indices) {
              bindParam(statement, columnCount + k + 1, record, updateFieldIndexes(k))
            }
            statement.addBatch()
            pending += 1
            // Flush inside the row loop so a batch never grows past batchSize.
            if (pending >= batchSize) {
              statement.executeBatch()
              pending = 0
            }
          }
          if (pending > 0) statement.executeBatch()
          conn.commit()
          true
        } finally {
          statement.close()
        }
      } catch {
        case e: Exception =>
          println(s"@@ ${e.getMessage}")
          // Do not commit a partially applied partition.
          if (conn != null) {
            try conn.rollback()
            catch {
              case rollbackError: Exception =>
                println(s"@@ rollback failed: ${rollbackError.getMessage}")
            }
          }
          false
      } finally {
        if (conn != null) conn.close()
      }
    Iterator.single(succeeded)
  }.collect()
  partitionStatuses.forall(identity)
}
/**
 * Checks that the MySQL table and the DataFrame have the same number of columns, with the
 * same names in the same order.
 *
 * The table's column names are read with a single forward pass over the metadata result
 * set, so no scrollable cursor is required, and the connection is released afterwards.
 *
 * @param tableName table to inspect
 * @param df        DataFrame to compare against
 * @throws Exception when the column counts differ or a column name differs at some position
 */
def verifyFieldConsistency(tableName: String, df: DataFrame): Unit = {
  val conn = MySqlPoolManager.getMysqlManager.getConnection
  val tableFieldNames: List[String] =
    try {
      val colResultSet = conn.getMetaData.getColumns(null, "%", tableName, "%")
      try {
        Iterator
          .continually(colResultSet.next())
          .takeWhile(identity)
          .map(_ => colResultSet.getString("column_name"))
          .toList
      } finally {
        colResultSet.close()
      }
    } finally {
      if (conn != null) conn.close()
    }

  val tableFieldNum = tableFieldNames.length
  val dfFieldNum = df.columns.length
  if (tableFieldNum != dfFieldNum) {
    throw new Exception(s"mysql表列数量:${tableFieldNum}与DataFrame数量:${dfFieldNum}不一致")
  }
  tableFieldNames.zip(df.columns).foreach { case (tableFieldName, dfFieldName) =>
    if (tableFieldName != dfFieldName) {
      throw new Exception(s"mysql表列名${tableFieldName}与DataFrame列名${dfFieldName}不一致")
    }
  }
}
}
package edu.xd.spark
import java.io.Serializable
import java.sql.Connection
import com.mchange.v2.c3p0.ComboPooledDataSource
/**
 * Wrapper around a c3p0 `ComboPooledDataSource` configured from `db.properties`.
 *
 * The pool is configured while the instance is constructed; a configuration failure is
 * printed and construction continues with whatever settings were applied before it.
 */
class MySqlPool extends Serializable {
  private val cpds: ComboPooledDataSource = new ComboPooledDataSource(true)

  try {
    // Every setting is read from the same properties file.
    val setting: String => String = key => PropertiyUtils.getFileProperties("db.properties", key)

    cpds.setJdbcUrl(setting("mysql.jdbc.url"))
    cpds.setDriverClass(setting("mysql.driver"))
    cpds.setUser(setting("mysql.jdbc.user"))
    cpds.setPassword(setting("mysql.jdbc.password"))
    cpds.setMinPoolSize(setting("mysql.pool.jdbc.minPoolSize").toInt)
    cpds.setMaxPoolSize(setting("mysql.pool.jdbc.maxPoolSize").toInt)
    cpds.setAcquireIncrement(setting("mysql.pool.jdbc.acquireIncrement").toInt)
    cpds.setMaxStatements(setting("mysql.pool.jdbc.maxStatements").toInt)
  } catch {
    case e: Exception => e.printStackTrace()
  }

  /**
   * Obtains a connection from the pool.
   *
   * @return a connection, or null when the pool could not provide one (the error is printed)
   */
  def getConnection: Connection =
    try {
      cpds.getConnection()
    } catch {
      case ex: Exception =>
        ex.printStackTrace()
        null
    }

  /** Closes the underlying data source; a failure is printed and otherwise ignored. */
  def close(): Unit =
    try {
      cpds.close()
    } catch {
      case ex: Exception =>
        ex.printStackTrace()
    }
}
package edu.xd.spark
import java.util.Properties
/** Helpers for reading values from `.properties` files on the classpath. */
object PropertiyUtils {
  /**
   * Reads one property from a properties file on the classpath.
   *
   * The file is opened, parsed and closed on every call.
   *
   * @param fileName      classpath-relative name of the properties file
   * @param propertityKey key to look up
   * @return the property value, or null when the key is absent
   * @throws java.io.FileNotFoundException when the file is not on the classpath
   */
  def getFileProperties(fileName: String, propertityKey: String): String = {
    val stream = this.getClass.getClassLoader.getResourceAsStream(fileName)
    // getResourceAsStream returns null for a missing resource; fail with a clear message
    // instead of a NullPointerException from Properties.load.
    if (stream == null) {
      throw new java.io.FileNotFoundException(s"properties file not found on classpath: $fileName")
    }
    try {
      val prop = new Properties()
      prop.load(stream)
      prop.getProperty(propertityKey)
    } finally {
      stream.close()
    }
  }

  /**
   * Reads one property from `db.properties` on the classpath.
   *
   * @param key key to look up
   * @return the property value, or null when the key is absent
   * @throws java.io.FileNotFoundException when `db.properties` is not on the classpath
   */
  def getDbProperties(key: String): String = getFileProperties("db.properties", key)
}
调用 saveOrUpdateAndCreateTableIfNotExits(tableName: String, df: DataFrame) 方法时,若表不存在会自动建表,然后插入或更新数据。