tf.sequence_mask后做max操作或avg操作

以推荐商品为例:

(1)两个user;

(2)user在历史一个月中分别购买力2个和6个item;

(3)历史数据中 Input 的商品数量固定为3;

 

1. 定义数据

import numpy as np
import tensorflow as tf


# user真实的历史行为个数(item 个数)
hist_seq_num_list = np.array([[2], [6]])


# user历史行为中的item所对应的embedding
hist_user_embedding_list = np.array(
              [[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],           # 正常历史元素的embedding
               [11, 12, 13, 14, 15, 16, 17, 18, 19, 20],   # 正常历史元素的embedding
               [31, 32, 33, 34, 35, 36, 37, 38, 39, 40]],  # 需要mask
          
              [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],            # 正常历史元素的embedding
               [11, 12, 13, 14, 15, 16, 17, 18, 19, 20],   # 正常历史元素的embedding
               [31, 32, 33, 34, 35, 36, 37, 38, 39, 40]]   # 正常历史元素的embedding
             ])


print(hist_user_embedding_list.shape)
# (2, 3, 10)
# 2个user
# 每次input的历史序列的长度固定为3(历史 item 个数)
# item embedding的维度为10

 

2. mask 操作,具体函数用法请自行百度

# mask matrix
mask_list = tf.sequence_mask(hist_seq_num_list, 3, dtype=tf.float32)
with tf.compat.v1.Session() as sess:
    print("mask_list: ")
    print(sess.run(mask_list))
    print(mask_list.shape)
    print("-----"*10)


# transpose mask matrix
mask_transpose_list = tf.transpose(mask_list, (0, 2, 1))
with tf.compat.v1.Session() as sess:
    print("mask_transpose_list: ")
    print(sess.run(mask_transpose_list))
    print(mask_transpose_list.shape)
    print("-----"*10)
    

embedding_dim = hist_user_embedding_list.shape[-1]
print(embedding_dim)
print("-----"*10)


# 扩增 mask matrix
mask_tile_list = tf.tile(mask_transpose_list, [1, 1, embedding_dim])
with tf.compat.v1.Session() as sess:
    print("mask_tile_list: ")
    print(sess.run(mask_tile_list))
    print(mask_tile_list.shape)

 

3. max操作

# max element

hist = hist_user_embedding_list - (1-mask_tile_list) * 1e9
with tf.compat.v1.Session() as sess:
    print("hist: ")
    print(sess.run(hist))
    print("-----"*10)


hist_max = tf.reduce_max(hist, 1, keepdims=True)
with tf.compat.v1.Session() as sess:
    print("hist_max: ")
    print(sess.run(hist_max))

 

4. avg操作

# avg element

hist_sum = tf.reduce_sum(hist_user_embedding_list * mask_tile_list, 1, keepdims=True)
with tf.compat.v1.Session() as sess:
    print("hist_sum: ")
    print(sess.run(hist_sum))
    print("-----"*10)

    
user_hist_behavior_length = tf.reduce_sum(mask_list, axis=-1, keepdims=True)
with tf.compat.v1.Session() as sess:
    print("user_hist_behavior_length: ")
    print(sess.run(user_hist_behavior_length))
    print("-----"*10)

    
hist_mean = tf.divide(hist_sum, \
                      tf.cast(user_hist_behavior_length, tf.float32) + tf.constant(1e-8, tf.float32))
with tf.compat.v1.Session() as sess:
    print("hist_mean: ")
    print(sess.run(hist_mean))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值