【tensorflow】estimator使用

环境:tf 1.10.0

  1. 制作TFRecord数据文件(训练时再用tf.data.TFRecordDataset读取)。
  2. 编写input_fn,实现对TFRecordDataset的解析。传入参数是数据集路径名,返回结果是feature(字典)和label。解析通过parse_fn实现。parse_fn根据TFRecordDataset格式实现。
    1. 编写parse_fn,传入参数是数据集中的一条记录(一个序列化的tf.train.Example),返回结果是从这条记录中解析出来的多个特征(以及label)。
  3. 编写Estimator,传入model_fn、超参数params(字典)和模型存储地址(model_dir)。
    1. 编写model_fn,实现模型结构。传入参数是feature、label(两者由input_fn返回)、mode(由Estimator在调用时提供)以及params(直接来自Estimator中传入的超参数)。返回结果是对应mode的EstimatorSpec。

 

具体示例:

tag_vector.py

import tensorflow as tf
import numpy as np
from my_model.max_pool_model.Config import config
from my_model.max_pool_model.make_dataset2 import read_pre_emb

tf.logging.set_verbosity(tf.logging.INFO)
# Pre-trained category embedding matrix (one row per line of the vector file);
# model_fn uses it as the initial value of the trainable "emb" variable.
pre_trained_emb = read_pre_emb(config["pretrained_vector_path"])

def model_fn(features, labels, mode, params):
    """Score a title against a tag profile by the max per-bucket cosine similarity.

    Each masked tag embedding is assigned to the bucket(s) at the position of
    its largest component. Tag vectors that share a bucket are summed and
    L2-normalized into one profile vector per bucket. The prediction is the
    largest cosine similarity between the L2-normalized title vector and any
    bucket profile.

    Args:
        features: dict with "musk" (tag mask), "cate_id" (ids looked up in the
            embedding table) and "title_vec", as built by tf_input_fn.
            Assumed shapes (TODO confirm): musk/cate_id (batch, max_tag_len),
            title_vec (batch, vector_dim).
        labels: regression targets, one per example; unused in PREDICT mode.
        mode: a tf.estimator.ModeKeys value supplied by the Estimator.
        params: hyper-parameter dict; reads "vector_dim" and "learning_rate".

    Returns:
        tf.estimator.EstimatorSpec for the requested mode.
    """
    vector_dim = params["vector_dim"]

    # Repeat the per-tag mask across the embedding axis so padded tags
    # become all-zero vectors after the multiply below.
    musk = features["musk"]
    musk_vec = tf.tile(tf.expand_dims(musk, -1), [1, 1, vector_dim])

    emb = tf.Variable(initial_value=pre_trained_emb, name="emb", trainable=True)
    tag_id = features["cate_id"]
    tag_vec = tf.nn.embedding_lookup(emb, tag_id)  # (batch, length, vector_dim)
    tag_vec = tf.multiply(musk_vec, tag_vec)

    # Bucket assignment: 1 wherever a component equals its row maximum.
    # Ties (and all-zero padded rows) mark several buckets; padded rows add
    # nothing to the profile because their vectors are zero.
    row_max = tf.reduce_max(tag_vec, axis=2)
    row_max_vec = tf.tile(tf.expand_dims(row_max, -1), [1, 1, vector_dim])
    bucket_mask = tf.cast(tf.equal(tag_vec, row_max_vec), tf.float32)
    bucket_mask = tf.transpose(bucket_mask, perm=[0, 2, 1])  # (batch, bucket, length)

    # Sum the tag vectors that fall in each bucket, then unit-normalize each profile.
    profile_matrix = tf.matmul(bucket_mask, tag_vec)  # (batch, bucket, vector_dim)
    profile_matrix = tf.nn.l2_normalize(profile_matrix, axis=2)

    title_vec = tf.nn.l2_normalize(features["title_vec"], axis=1)
    title_repeat = tf.tile(tf.expand_dims(title_vec, 1), [1, vector_dim, 1])

    # Cosine similarity per bucket is the dot product over the embedding axis
    # (axis 2). Summing over axis 1 would mix buckets and give
    # title[j] * sum_i profile[i, j], which is not a similarity.
    score = tf.reduce_sum(tf.multiply(title_repeat, profile_matrix), axis=2)  # (batch, bucket)
    # Max over buckets. reduce_max works for any batch size; a reshape to
    # params["batch_size"] would fail on smaller predict/eval batches.
    max_cos_similarity = tf.reduce_max(score, axis=1)  # (batch,)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=max_cos_similarity)

    loss = tf.losses.mean_squared_error(labels, max_cos_similarity)
    # NOTE: the key stays "accuracy" for compatibility, but the value is the MSE.
    metrics = {
        "accuracy": tf.metrics.mean_squared_error(labels, max_cos_similarity)
    }

    # Only build the optimizer (and its slot variables) when actually training.
    train_op = None
    training_hooks = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=params["learning_rate"])
        train_op = optimizer.minimize(
            loss=loss, global_step=tf.train.get_global_step())
        training_hooks = [tf.train.LoggingTensorHook({"loss": loss}, every_n_iter=10)]

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics,
        training_hooks=training_hooks)

# The shared config dict doubles as the Estimator's hyper-parameter dict.
params = config

# Checkpoints are written to / restored from ./ckpt/.
model = tf.estimator.Estimator(
    model_dir="./ckpt/",
    params=params,
    model_fn=model_fn,
)

def parse_fn(example_proto):
    """Decode one serialized tf.train.Example into (title_vec, cate_id, musk, label).

    The three array features are stored as raw bytes and decoded with
    tf.decode_raw as float32 / int64 / float32 respectively; the label is a
    scalar float.
    """
    schema = {
        "title_vec": tf.FixedLenFeature((), tf.string),
        "cate_id": tf.FixedLenFeature((), tf.string),
        "musk": tf.FixedLenFeature((), tf.string),
        "label": tf.FixedLenFeature((), tf.float32),
    }
    parsed = tf.parse_single_example(example_proto, schema)

    title_vec = tf.decode_raw(parsed['title_vec'], tf.float32)
    cate_id = tf.decode_raw(parsed['cate_id'], tf.int64)
    musk = tf.decode_raw(parsed['musk'], tf.float32)
    label = tf.cast(parsed['label'], tf.float32)
    return title_vec, cate_id, musk, label


def tf_input_fn(dataset_path):
    """Build the (features, label) batch tensors from a TFRecord file.

    Records are parsed with parse_fn, shuffled, repeated config["repeat_num"]
    times and grouped into batches of config["batch_size"] (the incomplete
    final batch is dropped).
    """
    dataset = (
        tf.data.TFRecordDataset(dataset_path)
        .map(parse_fn)
        .shuffle(buffer_size=config["buffer_size"])
        .repeat(config["repeat_num"])
        .batch(config["batch_size"], drop_remainder=True)
    )
    next_batch = dataset.make_one_shot_iterator().get_next()
    title_vec, cate_id, musk, label = next_batch
    return {"title_vec": title_vec, "cate_id": cate_id, "musk": musk}, label

def input_fn():
    """Zero-argument input function reading the training set at config["sample_path"]."""
    return tf_input_fn(config["sample_path"])


if __name__ == "__main__":
    model.train(input_fn, steps=config["step"])

Config.py

# Hyper-parameters and paths shared by the dataset builder and the model.
config = dict(
    max_tag_len=15,        # tag slots per sample (cate2id truncates/pads to this)
    batch_size=32,
    repeat_num=5,          # passes over the dataset (Dataset.repeat)
    learning_rate=1e-3,    # Adam learning rate
    vector_dim=56,         # tile width for tag/title vectors (presumably the embedding dim)
    buffer_size=100000,    # shuffle buffer size
    sample_path="E:/code/deepFM/my_model/max_pool_model/data/tag_vector.trainset",
    pretrained_vector_path="E:/code/deepFM/my_model/max_pool_model/data/cate.vector",
    step=20000,            # number of training steps
)

make_dataset.py(注意:tag_vector.py 中是通过 `my_model.max_pool_model.make_dataset2` 导入 read_pre_emb 的,文件名需与导入路径保持一致)

import tensorflow as tf
from my_model.max_pool_model.Config import config
import numpy as np

# Number of tag slots per sample; cate2id truncates or zero-pads every tag list to this length.
max_tag_len = config["max_tag_len"]

def gather_cateset(catefile_prefix, date_list):
    """Merge per-day category vector files into one vocabulary file and one vector file.

    Each input file "<catefile_prefix>.<date>" holds lines of the form
    "<category>\\t<vector string>". A category seen on several days keeps
    the vector from the last file, but stays at the position where it was
    first seen. Line i of the vocabulary file and line i of the vector file
    describe the same category.

    Args:
        catefile_prefix: path prefix of the per-day files and of the outputs.
        date_list: date suffixes to read, in order.

    Returns:
        (vocabulary_path, vector_path) of the two files written.
    """
    cate_vec_dict = {}
    for cur_date in date_list:
        cur_catefile = catefile_prefix + "." + cur_date
        with open(cur_catefile, "r", encoding="utf8") as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    # A blank line would otherwise fail the two-field unpacking below.
                    continue
                cate, vector = line.split("\t")
                cate_vec_dict[cate] = vector

    vocab_path = catefile_prefix + ".vocabulary"
    vector_path = catefile_prefix + ".vector"
    with open(vocab_path, "w", encoding="utf8") as vocab_out, \
            open(vector_path, "w", encoding="utf8") as vector_out:
        for cate, vector in cate_vec_dict.items():
            vocab_out.write(cate + "\n")
            vector_out.write(vector + "\n")
    return vocab_path, vector_path

def _parse_sample_line(raw_line, vocab_dict, max_len):
    """Parse one sample line into (vector, label, tag_id, musk), or return None to skip it.

    Expected layout (whitespace-separated, 6 fields):
        date  uin  "tag:w,tag:w,...||age||gender||grade||city||"  label  title  v1,v2,...
    A line is skipped when it does not have 6 fields, when its feature group
    does not split into 6 "||" parts, or when its label is below 1.0. Tag
    names are mapped through vocab_dict; a tag missing from the vocabulary
    raises KeyError. The id list is truncated to max_len and right-padded
    with id 0. musk marks real tags with 1 and padding with 0.

    NOTE(review): split() breaks on any whitespace, so a title that contains
    a space fails the 6-field check and the line is silently dropped. If the
    files are tab-separated, split("\\t") may be what was intended; confirm.
    """
    fields = raw_line.strip().split()
    if len(fields) != 6:
        return None
    gather_parts = fields[2].split("||")
    if len(gather_parts) != 6:
        return None
    label = float(fields[3])
    if label < 1.0:
        return None
    vector = [float(v) for v in fields[5].split(",")]

    # Only the first "||" group (user categories) is used; the ":weight" suffix is dropped.
    tag_id = [vocab_dict[t.split(":")[0]] for t in gather_parts[0].split(",")][:max_len]
    n_real = len(tag_id)
    n_pad = max_len - n_real
    return vector, label, tag_id + [0] * n_pad, [1] * n_real + [0] * n_pad


def cate2id(sample_prefix, date_list, vocab_path, max_len=None):
    """Convert the per-day sample files into parallel feature lists.

    Args:
        sample_prefix: path prefix; "<sample_prefix>.<date>" is read for each date.
        date_list: date suffixes to read, in order.
        vocab_path: vocabulary file with one category per line; the 0-based
            line number is the category id.
        max_len: number of tag slots per sample; defaults to the module-level
            max_tag_len.

    Returns:
        (vec_list, label_list, tag_list, musk_list), one entry per kept sample.
    """
    if max_len is None:
        max_len = max_tag_len

    with open(vocab_path, "r", encoding="utf8") as fvocab:
        # A duplicated entry keeps the id of its last occurrence.
        vocab_dict = {line.strip(): i for i, line in enumerate(fvocab)}

    vec_list, label_list, tag_list, musk_list = [], [], [], []
    for cur_date in date_list:
        with open(sample_prefix + "." + cur_date, "r", encoding="utf8") as fsample:
            for raw_line in fsample:
                parsed = _parse_sample_line(raw_line, vocab_dict, max_len)
                if parsed is None:
                    continue
                vector, label, tag_id, musk = parsed
                vec_list.append(vector)
                label_list.append(label)
                tag_list.append(tag_id)
                musk_list.append(musk)

    return vec_list, label_list, tag_list, musk_list

def _make_example(title_vec, tag_id, musk, label):
    """Build a tf.train.Example for one sample.

    The arrays are stored as raw bytes (float32 / int64 / float32). These
    dtypes must match the ones used with tf.decode_raw on the reading side.
    """
    def _bytes_feature(arr):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[arr.tobytes()]))

    record = tf.train.Features(
        feature={
            "title_vec": _bytes_feature(title_vec.astype(np.float32)),
            "cate_id": _bytes_feature(tag_id.astype(np.int64)),
            "musk": _bytes_feature(musk.astype(np.float32)),
            "label": tf.train.Feature(float_list=tf.train.FloatList(value=[label])),
        }
    )
    return tf.train.Example(features=record)


def save_tfrecords(sample_prefix, date_list, vocab_path, save_path):
    """Convert the sample files into a TFRecord file with one Example per sample.

    Args:
        sample_prefix, date_list, vocab_path: forwarded to cate2id.
        save_path: output TFRecord path.
    """
    # Convert the cate labels in the sample files to cate ids.
    vec_list, label_list, tag_list, musk_list = cate2id(sample_prefix, date_list,
                                                        vocab_path)
    print("dataset size:{}".format(len(label_list)))
    # Whole-list conversion assumes every row has the same length.
    vec_arr = np.asarray(vec_list)
    tag_arr = np.asarray(tag_list)
    musk_arr = np.asarray(musk_list)
    with tf.python_io.TFRecordWriter(save_path) as writer:
        for title_vec, tag_id, musk, label in zip(vec_arr, tag_arr, musk_arr, label_list):
            writer.write(_make_example(title_vec, tag_id, musk, label).SerializeToString())



def main():
    """Merge the per-day category vectors, then write the TFRecord training set."""
    date_list = ["20210116", "20210117"]
    # Merge several days of cate vectors into one vocabulary/vector file pair.
    vocab_path, _vector_path = gather_cateset("./data/cate", date_list)
    save_tfrecords("./data/sample", date_list, vocab_path, "./data/tag_vector.trainset")

def read_pre_emb(emb_path):
    """Load a comma-separated vector file into a float32 matrix.

    Each line of the file becomes one row of the returned numpy array.
    """
    with open(emb_path, "r", encoding="utf8") as fin:
        rows = [[float(tok) for tok in raw.strip().split(",")] for raw in fin]
    return np.asarray(rows, dtype=np.float32)


# When run as a script, build the vocabulary/vector files and the TFRecord training set.
if __name__ == "__main__":
    main()

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值