PyFlink 加载Keras 模型进行参数预测

 PyFlink 记载Keras 模型,进行实时性参数预测

# -*- coding: utf-8 -*
import logging
import os

from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import StreamTableEnvironment, EnvironmentSettings, DataTypes
from pyflink.table.udf import ScalarFunction, udf

"""
加载Mysql中的重连数据进行模型训练、验证
"""

# ########################### 初始化流处理环境 ###########################

# 创建 Blink 流处理环境,注意此处需要指定 StreamExecutionEnvironment,否则无法导入 java 函数
env = StreamExecutionEnvironment.get_execution_environment()
env_settings = EnvironmentSettings.new_instance().in_streaming_mode().use_blink_planner().build()
env.set_parallelism(1)
t_env = StreamTableEnvironment.create(env, environment_settings=env_settings)
# 设置该参数以使用 UDF
t_env.get_config().get_configuration().set_boolean("python.fn-execution.4memory.managed", True)
t_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", "80m")
# ########################### 指定 jar 依赖 ###########################

# dir_kafka_sql_connect = os.path.join(os.path.abspath(os.path.dirname(__file__)),
#                                      'flink-sql-connector-kafka_2.11-1.11.2.jar')
# t_env.get_config().get_configuration().set_string("pipeline.jars", 'file:///' + dir_kafka_sql_connect)


filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'ram.log')
logging.basicConfig(filename=filename, level=logging.INFO)


# ########################### 注册 UDF ###########################

class myKerasMLP(ScalarFunction):
    def __init__(self):
        print("Model __init__方法")
        # 加载模型
        self.model_name = 'Parameter_Predict_Net'
        self.weights = 'Parameter_Predict_Net_weights'
        self.redis_params = dict(host='localhost', password='123456', port=6379, db=1)
        self.model = None
        # y 的定义域
        self.classes = list(range(10))
        # 自定义的 4 类指标,用于评估模型和样本,指标值将暴露给外部系统以便于实时监控模型的状况
        self.metric_counter = None  # 从作业开始至今的所有样本数量
        self.metric_predict_acc = 0  # 模型预测的准确率(用过去 10 条样本来评估)
        self.metric_distribution_y = None  # 标签 y 的分布
        self.metric_total_10_sec = None  # 过去 10 秒内训练过的样本数量
        self.metric_right_10_sec = None  # 过去 10 秒内的预测正确的样本数

    def open(self, function_context):
        """
        访问指标系统,并注册指标,以便于在 webui (localhost:8081) 实时查看算法的运行情况。
        :param function_context:
        :return:
        """
        if self.model:
            print("模型已加载..")
        else:
            self.model = self.load_model()
        self.model.summary()

        print("Model open方法")
        # 访问指标系统,并定义 Metric Group 名称为 online_ml 以便于在 webui 查找
        # Metric Group + Metric Name 是 Metric 的唯一标识
        metric_group = function_context.get_metric_group().add_group("online_ml")

        # 目前 PyFlink 1.11.2 版本支持 4 种指标:计数器 Counters,量表 Gauges,分布 Distribution 和仪表 Meters 。
        # 目前这些指标都只能是整数

        # 1、计数器 Counter,用于计算某个东西出现的次数,可以通过 inc()/inc(n:int) 或 dec()/dec(n:int) 来增加或减少值
        self.metric_counter = metric_group.counter('sample_count')  # 训练过的样本数量

        # 2、量表 Gauge,用于根据业务计算指标,可以比较灵活地使用
        # 目前 pyflink 只支持 Gauge 为整数值
        metric_group.gauge("prediction_acc", lambda: int(self.metric_predict_acc * 100))

        # 3、分布 Distribution,用于报告某个值的分布信息(总和,计数,最小,最大和平均值)的指标,可以通过 update(n: int) 来更新值
        # 目前 pyflink 只支持 Distribution 为整数值
        self.metric_distribution_y = metric_group.distribution("metric_distribution_y")

        # 4、仪表 Meters,用于汇报平均吞吐量,可以通过 mark_event(n: int) 函数来更新事件数。
        # 统计过去 10 秒内的样本量、预测正确的样本量
        self.metric_total_10_sec = metric_group.meter("total_10_sec", time_span_in_seconds=10)
        self.metric_right_10_sec = metric_group.meter("right_10_sec", time_span_in_seconds=10)

    def eval(self, *args):
        """
         模型预测
        :param args: 参数集合
        :return:
        """
        from sklearn.preprocessing import StandardScaler
        import numpy as np
        import redis
        import pickle
        redis = redis.StrictRedis(**self.redis_params)
        # 加载训练好的StandardScaler,应用于单条记录的归一化
        x_sc = pickle.loads(redis.get("x_sc"))
        y_sc = pickle.loads(redis.get("y_sc"))
        # 拼接参数
        a = []
        for u in args:
            a.append(u)
        # shape :(7,1)
        print("a:", np.array(a))
        # shape :(1,7)
        b = np.transpose(np.array(a).reshape(-1, 1))
        # 数据归一化
        data = x_sc.transform(b)
        y_pred = self.model.predict(data)
        # 反归一化
        trueY = y_sc.inverse_transform(y_pred)
        # 返回预测结果
        return trueY[0][0], trueY[0][1]

    def load_model(self):
        """
        加载模型,如果 redis 里存在模型,则优先从 redis 加载,否则初始化一个新模型
        :return:
        """
        import redis
        import pickle
        import logging
        from keras.models import model_from_json
        from redis import StrictRedis
        logging.info('载入模型!')
        redis = redis.StrictRedis(**self.redis_params)
        model = None

        try:
            # 从redis中获取模型、应用pickle.loads加载模型
            print(redis.get("NT_Parameter_Predict_Net"))
            model = model_from_json(redis.get("NT_Parameter_Predict_Net"))
            model.set_weights(pickle.loads(redis.get("NT_Parameter_Predict_Net_weights")))
            model.summary()
        except TypeError:
            logging.error('Redis 内没有指定名称的模型,请先训练模型保存至Redis')

        return model

##############################################
#      特征输入:
#      hotime, before_ta, before_rssi, after_ta, after_rssil, nb_tath, nb_rssith
#      训练输出:
#      nbrta nbrssithrd
#
##############################################
myKerasMLP = udf(myKerasMLP(), input_types=[DataTypes.STRING(), DataTypes.STRING(), DataTypes.STRING(), DataTypes.STRING(),
                                            DataTypes.STRING(), DataTypes.STRING(), DataTypes.STRING()],
                 result_type=DataTypes.ARRAY(DataTypes.FLOAT()))

print('UDF 模型加载完成!')
# t_env.create_temporary_system_function('train_and_predict', myKerasMLP)
t_env.register_function('train_and_predict', myKerasMLP)
print('UDF 注册成功')

# ########################### 创建源表(source) ###########################
# 使用 MySQL-CDC 连接器从 MySQL 的 binlog 里提取更改。
# 该连接器非官方连接器,写法请参照扩展阅读 2。

t_env.execute_sql("""
    CREATE TABLE source (
        hotime STRING ,  
        before_ta STRING ,
        before_rssi STRING ,
        after_ta STRING ,
        after_rssil STRING ,
        nb_tath STRING ,
        nb_rssith STRING ,
        nbr_rssi STRING ,         
        nbr_ta STRING         
    ) WITH (
        'connector' = 'jdbc',
        'url' = 'jdbc:mysql://localhost:3306/hadoop',
        'table-name' = 'nt_data',
        'username' = 'root',
        'password' = '123456'
    )
""")


t_env.execute_sql("""
CREATE TABLE print_table (
        hotime STRING ,  
        before_ta STRING ,
        before_rssi STRING ,
        after_ta STRING ,
        after_rssil STRING ,
        nb_tath STRING ,
        nb_rssith STRING ,
        predict ARRAY<FLOAT >
) WITH (
 'connector' = 'print'
)
""")

# ###########################  ###########################
t_env.sql_query("""
SELECT
    hotime ,  
    before_ta ,
    before_rssi ,
    after_ta ,
    after_rssil ,
    nb_tath ,
    nb_rssith ,
    train_and_predict(hotime, before_ta, before_rssi, after_ta, after_rssil, nb_tath, nb_rssith) predict
FROM
    source
""").insert_into("print_table")

t_env.execute('NT重连预测参数')

参考:

https://blog.csdn.net/weixin_39334709/article/details/109893599?spm=1001.2014.3001.5502

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值