利用tensorflow estimator API实现双塔推荐算法

本文展示了如何利用tensorflow estimator API实现双塔推荐算法的分布式训练。提供了包括特征处理、模型定义、数据输入和主函数在内的完整代码模板,适用于自定义数据和模型结构。在config.py中配置参数,通过调整run_on_cluster参数可选择单机或分布式运行。
摘要由CSDN通过智能技术生成

本文完整代码见: https://github.com/cdj0311/two_tower_recommendation_system

Tensorflow estimator实现分布式训练很简单,只需要将数据进行相应的切分丢给模型就可以很方便的完成分布式训练了。以下代码是一个完整的推荐算法模板,可根据自己的需要修改数据读取和模型结构部分,tensorflow==1.13.1。

1. 特征处理部分,feature_processing.py

#coding:utf-8
import tensorflow as tf
from tensorflow import feature_column as fc
import config

FLAGS = config.FLAGS

class FeatureConfig(object):
    def __init__(self):
        self.user_columns = dict()
        self.item_columns = dict()
        self.feature_spec = dict()

    def create_features_columns(self):
        # 向量类特征
        user_vector = fc.numeric_column(key="user_vector", shape=(128,), default_value=[0.0] * 128, dtype=tf.float32)
        item_vector = fc.numeric_column(key="item_vector", shape=(128,), default_value=[0.0] * 128, dtype=tf.float32)
        
        # 分桶类特征
        age = fc.numeric_column(key="age", shape=(1,), default_value=[0], dtype=tf.int64)
        age = fc.bucketized_column(input_fc, boundaries=[0,10,20,30,40,50,60,70,80])
        age = fc.embedding_column(age, dimension=32, combiner='mean')
        
        # 分类特征
        city = fc.categorical_column_with_identity(key="city", num_buckets=1000, default_value=0)
        city = fc.embedding_column(city, dimension=32, combiner='mean')
        
        # hash特征
        device_id = fc.categorical_column_with_hash_bucket(key="device_id", 
                    hash_bucket_size=1000000, dtype=tf.int64)
        device_id = fc.embedding_column(device_id, dimension=32, combiner='mean')

        item_id = fc.categorical_column_with_hash_bucket(key="item_id", 
                    hash_bucket_size=1000000, dtype=tf.int64)
        item_id = fc.embedding_column(item_id, dimension=32, combiner='mean')
        
        self.user_columns["user_vector"] = user_vector
        self.user_columns["age"] = age
        self.user_columns["city"] = city
        self.user_columns["device_id"] = device_id
        self.item_columns["item_vector"] = item_vector
        self.ite
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值