以下为 Flink 流处理任务中,通过 addSink 向 MySQL 写入数据的两种方式(单条写入与批量写入),仅供参考。
一、单条处理
package wangjian.sink
import java.sql.{Connection, DriverManager, PreparedStatement}
import java.util.Properties
import com.alibaba.fastjson.JSONObject
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction
import org.slf4j.{Logger, LoggerFactory}
/**
 * Flink sink that writes every incoming record to the table `wj_shop_goods_all`
 * with one JDBC `INSERT` per record.
 *
 * Connection settings are read from `properties` under the keys `mysql.driver`,
 * `mysql.url`, `mysql.username` and `mysql.password`.
 *
 * For each record the `data` JSON object is read and its fields `id`,
 * `after_sale`, `all_hit` and `all_illegal_words` are bound to the statement.
 * The sink is best-effort: a failure on one record is logged and swallowed so
 * that the stream keeps running.
 *
 * @param properties JDBC connection settings (see keys above)
 * @author wmy
 * @create 2022/11/10 16:15
 */
class StarrocksSinkTumblingShopGoods(properties: Properties) extends RichSinkFunction[JSONObject] {
  val driver: String = properties.getProperty("mysql.driver")
  val url: String = properties.getProperty("mysql.url")
  val userName: String = properties.getProperty("mysql.username")
  val passWord: String = properties.getProperty("mysql.password")
  private val LOG: Logger = LoggerFactory.getLogger(classOf[StarrocksSinkTumblingShopGoods])
  // JDBC handles are created in open(); null means "not opened yet".
  private var connection: Connection = null
  private var ps: PreparedStatement = null

  /**
   * Loads the JDBC driver, opens the connection and prepares the INSERT statement.
   *
   * @param parameters Flink configuration; not read by this sink
   */
  override def open(parameters: Configuration): Unit = {
    // invoke() may call this again after a partially failed open; release any
    // leftover handles first so they are not leaked.
    releaseResources()
    Class.forName(driver)
    connection = DriverManager.getConnection(url, userName, passWord)
    val fields =
      """
        |`id`,
        |`after_sale`,
        |`all_hit`,
        |`all_illegal_words`
      """.stripMargin
    // One placeholder per listed column: four columns, four parameters.
    val sql =
      s"""
         |insert into wj_shop_goods_all (${fields}) values(?,?,?,?)
      """.stripMargin
    ps = connection.prepareStatement(sql)
  }

  /**
   * Inserts one record. Any exception is logged and swallowed (best-effort sink).
   *
   * @param value record whose `data` object carries the column values
   */
  override def invoke(value: JSONObject): Unit = {
    try {
      if (connection == null || ps == null) {
        open(new Configuration)
      }
      val param = value.getJSONObject("data")
      setNullableInt(1, param.getInteger("id"))
      ps.setString(2, param.getString("after_sale"))
      setNullableInt(3, param.getInteger("all_hit"))
      ps.setString(4, param.getString("all_illegal_words"))
      ps.executeUpdate()
      LOG.info("STARROCKS-SINK-SUCCESS:"+param.getString("update_time")+" "+value)
    } catch {
      case e: Exception =>
        // Log at ERROR and keep the stack trace; the message alone is often null (e.g. NPE).
        LOG.error("STARROCKS-SINK-ERROR:"+e.getMessage+" "+value, e)
    }
  }

  /** Closes the prepared statement and then the connection. */
  override def close(): Unit = {
    try {
      if (ps != null) {
        ps.close()
      }
    } finally {
      // Runs even if closing the statement failed, so the connection is never leaked.
      if (connection != null) {
        connection.close()
      }
    }
  }

  /** Binds a possibly-null boxed Integer, writing SQL NULL instead of failing on unboxing. */
  private def setNullableInt(index: Int, boxed: java.lang.Integer): Unit = {
    if (boxed == null) {
      ps.setNull(index, java.sql.Types.INTEGER)
    } else {
      ps.setInt(index, boxed.intValue())
    }
  }

  /** Best-effort release of handles left over from an earlier open(); errors are only logged. */
  private def releaseResources(): Unit = {
    try {
      if (ps != null) {
        ps.close()
      }
    } catch {
      case e: Exception => LOG.warn("STARROCKS-SINK-CLOSE-ERROR:" + e.getMessage, e)
    }
    try {
      if (connection != null) {
        connection.close()
      }
    } catch {
      case e: Exception => LOG.warn("STARROCKS-SINK-CLOSE-ERROR:" + e.getMessage, e)
    }
    ps = null
    connection = null
  }
}
二、批量处理
package wangjian.sink
import java.sql.{Connection, DriverManager, PreparedStatement}
import java.util
import java.util.Properties
import com.alibaba.fastjson.JSONObject
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction
import org.slf4j.{Logger, LoggerFactory}
/**
 * Flink sink that buffers incoming records in memory and writes them to the
 * table `wj_shop_goods_all` as JDBC batches, one transaction per batch.
 *
 * Connection settings are read from `properties` under the keys `mysql.driver`,
 * `mysql.url`, `mysql.username` and `mysql.password`.
 *
 * A batch is flushed when the buffer reaches 1000 records and once more in
 * `close()`. The sink is best-effort: when a flush fails the transaction is
 * rolled back, the buffered records of that batch are dropped and the error is
 * logged, so that one bad batch can neither grow the buffer without bound nor
 * be re-inserted on every later flush.
 *
 * NOTE(review): the buffer is plain in-memory state and is not part of any
 * Flink checkpoint in this class, so records buffered at the time of a task
 * failure are lost - confirm this is acceptable for the job.
 *
 * @param properties JDBC connection settings (see keys above)
 * @author wmy
 * @create 2022/11/10 16:15
 */
class StarrocksBatchSinkTumblingShopGoods(properties: Properties) extends RichSinkFunction[JSONObject] {
  private val LOG: Logger = LoggerFactory.getLogger(classOf[StarrocksBatchSinkTumblingShopGoods])
  val driver: String = properties.getProperty("mysql.driver")
  val url: String = properties.getProperty("mysql.url")
  val userName: String = properties.getProperty("mysql.username")
  val passWord: String = properties.getProperty("mysql.password")
  // Records waiting to be written by the next flush.
  private val batchData: util.ArrayList[JSONObject] = new util.ArrayList[JSONObject]()
  // Number of buffered records that triggers a flush.
  private val batchSize: Int = 1000
  // JDBC handles are created in open(); null means "not opened yet".
  private var connection: Connection = null
  private var ps: PreparedStatement = null

  /**
   * Loads the JDBC driver, opens the connection in manual-commit mode and
   * prepares the INSERT statement.
   *
   * @param parameters Flink configuration; not read by this sink
   */
  override def open(parameters: Configuration): Unit = {
    Class.forName(driver)
    connection = DriverManager.getConnection(url, userName, passWord)
    // executeBatchInsert() commits explicitly. JDBC connections start in
    // auto-commit mode, where calling commit() raises an SQLException, so
    // switch to manual commit here.
    connection.setAutoCommit(false)
    val fields =
      """
        |`id`,
        |`after_sale`,
        |`all_hit`,
        |`all_illegal_words`
      """.stripMargin
    val sql =
      s"""
         |insert into wj_shop_goods_all (${fields}) values(?,?,?,?)
      """.stripMargin
    ps = connection.prepareStatement(sql)
  }

  /**
   * Buffers one record and flushes the buffer once it holds `batchSize` records.
   * Any exception is logged and swallowed (best-effort sink).
   *
   * @param value record with an `index` string and a `data` object; `data` is
   *              tagged with `platform_type` and buffered
   */
  override def invoke(value: JSONObject): Unit = {
    try {
      if (connection == null || ps == null) {
        open(new Configuration)
      }
      val platformType = if(value.getString("index").contains("otm_content_order_shop_cert")) "waimai" else "dianshang"
      val param = value.getJSONObject("data")
      param.put("platform_type",platformType)
      batchData.add(param)
      if (batchData.size >= batchSize) {
        executeBatchInsert() // write the full batch
      }
      // Logged once the record has been accepted (buffered, and flushed if the batch was full).
      LOG.info("STARROCKS-BATCH-SINK-SUCCESS:"+param.getString("update_time")+" "+value)
    } catch {
      case e: Exception =>
        // Log at ERROR and keep the stack trace; the message alone is often null (e.g. NPE).
        LOG.error("STARROCKS-BATCH-SINK-ERROR:"+e.getMessage+" "+value, e)
    }
  }

  /**
   * Writes all buffered records as one JDBC batch and commits.
   *
   * On failure the pending JDBC batch is cleared, the transaction is rolled
   * back and the exception is rethrown to the caller. The buffer is emptied in
   * every case so a failed batch is never retried or re-inserted.
   */
  private def executeBatchInsert(): Unit = {
    val pending = batchData.size
    if (pending > 0) {
      try {
        for (i <- 0 until pending) {
          val param = batchData.get(i)
          // Bind the statement parameters for this row.
          setNullableInt(1, param.getInteger("id"))
          ps.setString(2, param.getString("after_sale"))
          setNullableInt(3, param.getInteger("all_hit"))
          ps.setString(4, param.getString("all_illegal_words"))
          ps.addBatch() // queue this row in the JDBC batch
        }
        ps.executeBatch() // send the batch
        connection.commit() // commit the transaction
      } catch {
        case e: Exception =>
          LOG.error("STARROCKS-BATCH-SINK-FLUSH-ERROR: dropping " + pending + " buffered records", e)
          discardFailedBatch()
          throw e
      } finally {
        batchData.clear() // empty the buffer whether or not the flush succeeded
      }
    }
  }

  /** Clears rows already queued on the statement and rolls back; cleanup errors are only logged. */
  private def discardFailedBatch(): Unit = {
    try {
      ps.clearBatch()
      connection.rollback()
    } catch {
      case cleanupError: Exception =>
        LOG.warn("STARROCKS-BATCH-SINK-ROLLBACK-ERROR:" + cleanupError.getMessage, cleanupError)
    }
  }

  /** Binds a possibly-null boxed Integer, writing SQL NULL instead of failing on unboxing. */
  private def setNullableInt(index: Int, boxed: java.lang.Integer): Unit = {
    if (boxed == null) {
      ps.setNull(index, java.sql.Types.INTEGER)
    } else {
      ps.setInt(index, boxed.intValue())
    }
  }

  /**
   * Flushes the remaining buffered records, then closes the prepared statement
   * and the connection. A failure of the final flush propagates to the caller
   * after the resources have been closed.
   */
  override def close(): Unit = {
    try {
      executeBatchInsert() // write the last, possibly partial, batch
    } finally {
      try {
        if (ps != null) {
          ps.close()
        }
      } finally {
        // Runs even if closing the statement failed, so the connection is never leaked.
        if (connection != null) {
          connection.close()
        }
      }
    }
  }
}