Spark-mllib模型序列化与反序列化

Spark-mllib模型序列化与反序列化

大家好,我是一拳就能打爆A柱的魔鬼筋肉人

不知道大家有没有这种需求,将训练好的模型保存在数据库中,今天给大家带来的是Spark-mllib的模型的序列化和反序列化。接下来我从下面几个点来记录这个需求:1、序列化,2、反序列化。

1、 main方法流程

其实main方法的流程很简单,简单来说就是训练模型、保存模型、从数据库中反序列化出来模型对象,最后使用模型做一个预测:

def main(args: Array[String]): Unit = {
    ....

    // 将模型保存到DM
    saveModel("kmeans", model)

    // 重新加载模型
    val reloadedModel: KMeansModel = reloadModel("kmeans")
    val i = reloadedModel.predict(Vectors.dense(0, 0, 0))
    println("ANS:" + i)
}

最终运行结果如下:

loading jdbc Driver ...
loading jdbc Driver ...
ANS:1

在数据库中可以看到分类结果如下:

SQL> select * from kmeansandres;

行号     item          class
---------- ------------- -----
1          [0.0,0.0,0.0] 1    
2          [0.1,0.1,0.1] 1    
3          [0.2,0.2,0.2] 1    
4          [9.0,9.0,9.0] 0    
5          [9.1,9.1,9.1] 0    
6          [9.2,9.2,9.2] 0 

所以分类是正确的。

2、序列化

跟着main方法的调用,紧接着就是saveModel方法,saveModel方法负责调用序列化方法model2BytesArray,并将Byte数组插入DM表中:

def saveModel(modelName: String, model: KMeansModel): Unit = {
    // 将对象转bytes数组
    val modelBytesArray: Array[Byte] = model2BytesArray(model)

    val utils: DMUtils = new DMUtils()
    val conn: Connection = utils.getConnection
    var sql = "INSERT INTO jc.KmeansModel3 VALUES (?,?,?)"
    val pstmt: PreparedStatement = conn.prepareStatement(sql)

    try {
        val date = new Date()
        pstmt.setLong(1, date.getTime)
        pstmt.setString(2, modelName)
        pstmt.setBytes(3, modelBytesArray)
        pstmt.executeUpdate()
    } catch {
        case ex: Exception => {
            println("insert model operation fail ...")
            ex.printStackTrace()
        }
    } finally {
        // 释放资源
        pstmt.close()
        conn.close()
    }
}

model2BytesArray方法负责将模型序列化成Array[Byte]:

def model2BytesArray(model: KMeansModel):Array[Byte] = {
    
    var baos: ByteArrayOutputStream = null
    var oos: ObjectOutputStream = null
    var modelArray: Array[Byte] = null
    try {
        baos = new ByteArrayOutputStream()
        oos = new ObjectOutputStream(baos)
        oos.writeObject(model)
        modelArray = baos.toByteArray()

    } catch {
        case ex: Exception => {
            println("serialize fail ...")
            ex.printStackTrace()
        }
    } finally {
        oos.close()
        baos.close()
    }
    modelArray
}

3、反序列化

其实在反序列化的过程我遇上了一点麻烦,没有根据数据库的数据类型来定义,所以导致反序列化过程中报错。根据main方法的流程,就到reloadModel方法,reloadModel通过sql语句获取模型的Byte数组,并调用反序列化方法bytesArray2Model得到KMeansModel:

def reloadModel(modelName: String): KMeansModel = {
    val utils: DMUtils = new DMUtils()
    val conn: Connection = utils.getConnection
    var sql = "select * from kmeansmodel3 where modelName='kmeans';"
    val pstmt: PreparedStatement = conn.prepareStatement(sql)
    var model: KMeansModel = null

    try {
        val set: ResultSet = pstmt.executeQuery()
        while (set.next()) {
            val bytes: Array[Byte] = set.getBytes(3)
            bytes.foreach(print(_))
            model = bytesArray2Model(bytes)
        }
    } catch {
        case ex: Exception => {}
        println("select model operation fail ...")
        ex.printStackTrace()
    } finally {
        // 释放资源
        pstmt.close()
        conn.close()
    }
    model
}

具体的反序列化如下:

def bytesArray2Model(bytesArray: Array[Byte]) = {
    var bais: ByteArrayInputStream = null
    var ois: ObjectInputStream = null
    var model: KMeansModel = null
    try {
        bais = new ByteArrayInputStream(bytesArray)
        ois = new ObjectInputStream(bais)
        model = ois.readObject().asInstanceOf[KMeansModel]
    } catch {
        case ex: Exception => {
            println("deserialize fail ...")
            ex.printStackTrace()
        }
    } finally {
        ois.close()
        bais.close()
    }
    model
}
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值