排序模型-FTRL

排序模型进阶-FTRL

1 问题

在实际项目的时候,经常会遇到训练数据非常大导致一些算法实际上不能操作的问题。比如在推荐行业中,因为DSP的请求数据量特别大,一个星期的数据往往有上百G,这种级别的数据在训练的时候,直接套用一些算法框架是没办法训练的,基本上在特征工程的阶段就一筹莫展。通常采用采样、截断的方式获取更小的数据集,或者使用大数据集群的方式进行训练,但是这两种方式在作者看来目前存在两个问题:

  • 采样数据或者截断数据的方式,非常的依赖前期的数据分析以及经验。
  • 大数据集群的方式,目前spark原生支持的机器学习模型比较少;使用第三方的算法模型的话,需要spark集群的2.3以上;而且spark训练出来的模型往往比较复杂,实际线上运行的时候,对内存以及QPS的压力比较大。
    在这里插入图片描述

2 在线优化算法-Online-learning

  1. 模型更新周期慢,不能有效反映线上的变化,最快小时级别,一 般是天级别甚至周级别。

  2. 模型参数少,预测的效果差;模型参数多线上predict的时候需要内存大,QPS无法保证。

  3. 对1采用On-line-learning的算法。

  4. 对2采用一些优化的方法,在保证精度的前提下,尽量获取稀疏解,从而降低模型参数的数量。
    比较出名的在线最优化的方法有:

TG(Truncated Gradient)
FOBOS(Forward-Backward Splitting)
RDA(Regularized Dual Averaging)
FTRL(Follow the Regularized Leader)
SGD算法是常用的online learning算法,它能学习出不错的模型,但学出的模型不是稀疏的。为此,学术界和工业界都在研究这样一种online learning算法,它能学习出有效的且稀疏的模型

2 FTRL

一种获得稀疏模型的优化方法
算法原理:http://vividfree.github.io/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/2015/12/05/understanding-FTRL-algorithm

3 离线数据训练FTRL模型

目的:通过离线TFRecords样本数据,训练FTRL模型
步骤:
1、构建模型
2、构建TFRecords的输入数据
3、train训练以及预测测试
完整代码:

import tensorflow as tf
from tensorflow.python import keras


class LrWithFtrl(object):
    """LR以FTRL方式优化
    """
    def __init__(self):
        self.model = keras.Sequential([
            keras.layers.Dense(1, activation='sigmoid', input_shape=(121,))
        ])

    @staticmethod
    def read_ctr_records():
        # 定义转换函数,输入时序列化的
        def parse_tfrecords_function(example_proto):
            features = {
                "label": tf.FixedLenFeature([], tf.int64),
                "feature": tf.FixedLenFeature([], tf.string)
            }
            parsed_features = tf.parse_single_example(example_proto, features)

            feature = tf.decode_raw(parsed_features['feature'], tf.float64)
            feature = tf.reshape(tf.cast(feature, tf.float32), [1, 121])
            label = tf.reshape(tf.cast(parsed_features['label'], tf.float32), [1, 1])
            return feature, label

        dataset = tf.data.TFRecordDataset(["./train_ctr_201904.tfrecords"])
        dataset = dataset.map(parse_tfrecords_function)
        dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.repeat(10000)
        return dataset

    def train(self, dataset):

        self.model.compile(optimizer=tf.train.FtrlOptimizer(0.03, l1_regularization_strength=0.01,
                                                            l2_regularization_strength=0.01),
                           loss='binary_crossentropy',
                           metrics=['binary_accuracy'])
        self.model.fit(dataset, steps_per_epoch=10000, epochs=10)
        self.model.summary()
        self.model.save_weights('./ckpt/ctr_lr_ftrl.h5')

    def predict(self, inputs):
        """预测
        :return:
        """
        # 首先加载模型
        self.model.load_weights('/root/toutiao_project/reco_sys/offline/models/ckpt/ctr_lr_ftrl.h5')
        init = tf.global_variables_initializer()

        with tf.Session() as sess:
            sess.run(init)
            predictions = self.model.predict(sess.run(inputs))
        return predictions


if __name__ == '__main__':
    lwf = LrWithFtrl()
    dataset = lwf.read_ctr_records()
    inputs, labels = dataset.make_one_shot_iterator().get_next()
    print(inputs, labels)
    lwf.predict(inputs)

在线预测

def lrftrl_sort_service(reco_set, temp, hbu):
    """
    排序返回推荐文章
    :param reco_set:召回合并过滤后的结果
    :param temp: 参数
    :param hbu: Hbase工具
    :return:
    """
    print(344565)
    # 排序
    # 1、读取用户特征中心特征
    try:
        user_feature = eval(hbu.get_table_row('ctr_feature_user',
                                              '{}'.format(temp.user_id).encode(),
                                              'channel:{}'.format(temp.channel_id).encode()))
        logger.info("{} INFO get user user_id:{} channel:{} profile data".format(
            datetime.now().strftime('%Y-%m-%d %H:%M:%S'), temp.user_id, temp.channel_id))
    except Exception as e:
        user_feature = []
        logger.info("{} WARN get user user_id:{} channel:{} profile data failed".format(
            datetime.now().strftime('%Y-%m-%d %H:%M:%S'), temp.user_id, temp.channel_id))
    reco_set = [13295, 44020, 14335, 4402, 2, 14839, 44024, 18427, 43997, 17375]
    if user_feature and reco_set:
        # 2、读取文章特征中心特征
        result = []
        for article_id in reco_set:
            try:
                article_feature = eval(hbu.get_table_row('ctr_feature_article',
                                                         '{}'.format(article_id).encode(),
                                                         'article:{}'.format(article_id).encode()))
            except Exception as e:
                article_feature = []

            if not article_feature:
                article_feature = [0.0] * 111
            f = []
            f.extend(user_feature)
            f.extend(article_feature)

            result.append(f)

        # 4、预测并进行排序是筛选
        arr = np.array(result)

        # 加载逻辑回归模型
        lwf = LrWithFtrl()
        print(tf.convert_to_tensor(np.reshape(arr, [len(reco_set), 121])))
        predictions = lwf.predict(tf.constant(arr))

        df = pd.DataFrame(np.concatenate((np.array(reco_set).reshape(len(reco_set), 1), predictions),
                                          axis=1),
                          columns=['article_id', 'prob'])

        df_sort = df.sort_values(by=['prob'], ascending=True)

        # 排序后,只将排名在前100个文章ID返回给用户推荐
        if len(df_sort) > 100:
            reco_set = list(df_sort.iloc[:100, 0])
        reco_set = list(df_sort.iloc[:, 0])

    return reco_set

4 TensorFlow FTRL 读取训练

训练数据说明

  • 原始特征用MurmurHash3的方式,将特征id隐射到(Long.MinValue, Long.MaxValue)范围

  • 保存成One-Hot的数据格式
    算法参数

  • lambda1:L1正则系数,参考值:10 ~ 15

  • lambda2:L2正则系数,参考值:10 ~ 15

  • alpha:FTRL参数,参考值:0.1

  • beta:FTRL参数,参考值:1.0

  • batchSize: mini-batch的大小,参考值:10000

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值