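Spark-mllib Model Serialization and Deserialization
Hi everyone, I'm the devilish muscle man who can smash an A-pillar with one punch.
Have you ever needed to save a trained model in a database? Today I'll walk through serializing and deserializing a Spark-mllib model. I'll cover it in three parts: 1. the main method flow, 2. serialization, 3. deserialization.
1. main method flow
The flow of the main method is simple: train the model, save it, deserialize the model object back out of the database, and finally use the reloaded model to make a prediction: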
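import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors

def main(args: Array[String]): Unit = {
  ....
  // Save the model to DM
  saveModel("kmeans", model)
  // Reload the model from the database
  val reloadedModel: KMeansModel = reloadModel("kmeans")
  // Predict the cluster of a single point with the reloaded model
  val i = reloadedModel.predict(Vectors.dense(0, 0, 0))
  println("ANS:" + i)
}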
The final output looks like this:
loading jdbc Driver ...
loading jdbc Driver ...
ANS:1
The clustering results stored in the database are:
SQL> select * from kmeansandres;
Row        item          class
---------- ------------- -----
1          [0.0,0.0,0.0] 1
2          [0.1,0.1,0.1] 1
3          [0.2,0.2,0.2] 1
4          [9.0,9.0,9.0] 0
5          [9.1,9.1,9.1] 0
6          [9.2,9.2,9.2] 0
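The reloaded model assigns [0.0,0.0,0.0] to cluster 1, which matches the table, so the prediction is correct.
2. Serialization
Following the calls in main, the next stop is saveModel. It calls the serialization method model2BytesArray and inserts the resulting byte array into a DM table: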
import java.sql.{Connection, PreparedStatement}
import java.util.Date

def saveModel(modelName: String, model: KMeansModel): Unit = {
  // Convert the model object to a byte array
  val modelBytesArray: Array[Byte] = model2BytesArray(model)
  val utils: DMUtils = new DMUtils()
  val conn: Connection = utils.getConnection
  val sql = "INSERT INTO jc.KmeansModel3 VALUES (?,?,?)"
  val pstmt: PreparedStatement = conn.prepareStatement(sql)
  try {
    // Parameters: insertion timestamp, model name, serialized model bytes
    val date = new Date()
    pstmt.setLong(1, date.getTime)
    pstmt.setString(2, modelName)
    pstmt.setBytes(3, modelBytesArray)
    pstmt.executeUpdate()
  } catch {
    case ex: Exception => {
      println("insert model operation fail ...")
      ex.printStackTrace()
    }
  } finally {
    // Release resources
    pstmt.close()
    conn.close()
  }
}
The model2BytesArray method serializes the model into an Array[Byte]:
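import java.io.{ByteArrayOutputStream, ObjectOutputStream}

def model2BytesArray(model: KMeansModel): Array[Byte] = {
  var baos: ByteArrayOutputStream = null
  var oos: ObjectOutputStream = null
  var modelArray: Array[Byte] = null
  try {
    baos = new ByteArrayOutputStream()
    oos = new ObjectOutputStream(baos)
    oos.writeObject(model)
    // Flush so that everything written so far reaches baos
    oos.flush()
    modelArray = baos.toByteArray()
  } catch {
    case ex: Exception => {
      println("serialize fail ...")
      ex.printStackTrace()
    }
  } finally {
    // Guard against null in case a stream was never opened
    if (oos != null) oos.close()
    if (baos != null) baos.close()
  }
  // Still null if serialization failed
  modelArray
}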
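Before involving the database at all, it can help to check that the serialize/deserialize round trip works in memory. The following is only a minimal sketch: the object name RoundTripSketch, the helpers toBytes and fromBytes, and the `centers` array are all hypothetical, and a plain Array[Array[Double]] stands in for the KMeansModel so that the example runs without Spark.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object RoundTripSketch {
  // Serialize any java.io.Serializable value into a byte array.
  def toBytes(obj: java.io.Serializable): Array[Byte] = {
    val baos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(baos)
    try {
      oos.writeObject(obj)
      oos.flush() // push buffered data into baos before copying it out
      baos.toByteArray
    } finally {
      oos.close()
    }
  }

  // Read one object back; the caller names the type it expects.
  def fromBytes[T](bytes: Array[Byte]): T = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try ois.readObject().asInstanceOf[T]
    finally ois.close()
  }

  // Hypothetical stand-in for a trained model: two cluster centers.
  val centers: Array[Array[Double]] =
    Array(Array(9.1, 9.1, 9.1), Array(0.1, 0.1, 0.1))
}
```

If RoundTripSketch.fromBytes[Array[Array[Double]]](RoundTripSketch.toBytes(RoundTripSketch.centers)) returns the same values, the serialization half is fine and any remaining problem is likely on the database side.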
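3. Deserialization
I actually ran into a bit of trouble here: I had not defined the type to match the data type of the database column, which caused errors during deserialization. Following the flow of main, we now reach reloadModel. It fetches the model's byte array with a SQL statement and calls the deserialization method bytesArray2Model to get a KMeansModel: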
import java.sql.{Connection, PreparedStatement, ResultSet}

def reloadModel(modelName: String): KMeansModel = {
  val utils: DMUtils = new DMUtils()
  val conn: Connection = utils.getConnection
  // Bind the model name as a parameter instead of hard-coding 'kmeans'
  val sql = "select * from kmeansmodel3 where modelName = ?"
  val pstmt: PreparedStatement = conn.prepareStatement(sql)
  var model: KMeansModel = null
  try {
    pstmt.setString(1, modelName)
    val set: ResultSet = pstmt.executeQuery()
    // If several rows match, the last one read wins
    while (set.next()) {
      // Column 3 holds the serialized model bytes
      val bytes: Array[Byte] = set.getBytes(3)
      model = bytesArray2Model(bytes)
    }
  } catch {
    case ex: Exception => {
      println("select model operation fail ...")
      ex.printStackTrace()
    }
  } finally {
    // Release resources
    pstmt.close()
    conn.close()
  }
  // Still null if no row matched or the query failed
  model
}
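When the column type does not match and deserialization fails, one quick diagnostic is to check whether the bytes read from the database still begin with the Java serialization stream header (magic number 0xACED followed by version 5). This is only a sketch of that check, not something from my original code: the object BlobCheck and the method looksLikeJavaSerialized are hypothetical names, and it assumes that whatever went wrong also damaged the first four bytes.

```scala
import java.io.{ByteArrayInputStream, DataInputStream, ObjectStreamConstants}

object BlobCheck {
  // True if the bytes start with the Java serialization stream header.
  def looksLikeJavaSerialized(bytes: Array[Byte]): Boolean = {
    if (bytes == null || bytes.length < 4) false
    else {
      val in = new DataInputStream(new ByteArrayInputStream(bytes))
      val magic = in.readShort()   // expected: STREAM_MAGIC (0xACED)
      val version = in.readShort() // expected: STREAM_VERSION (5)
      magic == ObjectStreamConstants.STREAM_MAGIC &&
        version == ObjectStreamConstants.STREAM_VERSION
    }
  }
}
```

Calling BlobCheck.looksLikeJavaSerialized(bytes) right after set.getBytes(3) separates "the database returned something else" from "the bytes are intact but the class could not be restored".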
The deserialization itself looks like this:
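import java.io.{ByteArrayInputStream, ObjectInputStream}

def bytesArray2Model(bytesArray: Array[Byte]): KMeansModel = {
  var bais: ByteArrayInputStream = null
  var ois: ObjectInputStream = null
  var model: KMeansModel = null
  try {
    bais = new ByteArrayInputStream(bytesArray)
    ois = new ObjectInputStream(bais)
    // Restore the object and cast it back to KMeansModel
    model = ois.readObject().asInstanceOf[KMeansModel]
  } catch {
    case ex: Exception => {
      println("deserialize fail ...")
      ex.printStackTrace()
    }
  } finally {
    // Guard against null in case a stream was never opened
    if (ois != null) ois.close()
    if (bais != null) bais.close()
  }
  // Still null if deserialization failed
  model
}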
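One thing to watch out for: bytesArray2Model returns null when deserialization fails, so a later call such as predict would throw a NullPointerException far away from the real cause. A possible alternative, sketched here under my own assumptions, is to return a scala.util.Try so that the caller has to handle the failure. The names SafeDeserialize and deserialize are hypothetical, and the sketch returns AnyRef rather than KMeansModel so that it runs without Spark.

```scala
import java.io.{ByteArrayInputStream, ObjectInputStream}
import scala.util.Try

object SafeDeserialize {
  // Success(obj) if the bytes hold a serialized object, Failure(ex) otherwise.
  def deserialize(bytes: Array[Byte]): Try[AnyRef] = Try {
    // The constructor already throws if the stream header is invalid
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try ois.readObject()
    finally ois.close()
  }
}
```

A caller could then pattern match on the result, casting to KMeansModel in the Success case and logging the exception in the Failure case, instead of checking for null.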