Multi-label classification inference with TensorFlow

1. Process overview

(1) Read the data that needs inference.

(2) Transform each sample's features into the form used during model training; this involves mapping id-type features to their integer ids (a minimal sketch of this step follows the list).

(3) Run the model to predict the labels.
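
To make step (2) concrete, here is a minimal, self-contained sketch of the id-mapping and normalization transform. The mapping contents, the statistics, and the helper name encode_sample are made-up placeholders for illustration, not part of the script in section 2.

import numpy as np

# Hypothetical mapping table and normalization stats; in the real script
# these are loaded from an Excel mapping file (see load_mapping_local below)
lang_to_id = {"en": 1, "zh": 2}
duration_mean, duration_std = 210.0, 60.0

def encode_sample(language, duration, embedding_str):
    """Turn raw fields into the numeric vector the model expects."""
    lang_fea = lang_to_id.get(language, 0)                      # unseen value -> 0
    dur_fea = (float(duration) - duration_mean) / duration_std  # z-score
    emb = [float(x) for x in embedding_str.split(",")]          # dense embedding
    return np.array(emb + [lang_fea, dur_fea], dtype=np.float32)

print(encode_sample("en", 185, "0.1,0.2,0.3"))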

2. Code

import os
import platform
import math

import numpy as np
import tensorflow as tf

plat = platform.system().lower()

tf.app.flags.DEFINE_string("tables", "", "Tables info")
tf.app.flags.DEFINE_string("outputs", "", "predict result output table")
tf.app.flags.DEFINE_string("checkpointDir", '', "Model checkpoint dir")
FLAGS = tf.app.flags.FLAGS

save_label = ["label1", "label2", "label3"]  # label names written to the output
domain = ""  # prefix of the checkpoint files, i.e. "<domain>_mll"


def load_mapping_local(mapping_table):
    """Load id mappings and duration statistics from a local Excel file."""
    import xlrd
    workbook = xlrd.open_workbook(mapping_table)
    sheet = workbook.sheet_by_index(0)

    # Mapping tables for the id-type features, plus normalization statistics
    art_to_id = dict()
    lang_to_id = dict()
    release_to_id = dict()
    duration_mean = 0
    duration_std = 0

    for index in range(1, sheet.nrows):
        line = sheet.row_values(index)
        if line[0] == "artist":
            art_to_id[line[1]] = int(line[2])
        elif line[0] == "language":
            lang_to_id[line[1]] = int(line[2])
        elif line[0] == "release":
            release_to_id[line[1]] = int(line[2])
        else:
            if line[1] == "mean":
                duration_mean = float(line[2])
            else:
                duration_std = float(line[2])

    return art_to_id, lang_to_id, release_to_id, duration_mean, duration_std


def load_train_local(data_table):
    """Load the raw samples to predict from a local Excel file."""
    import xlrd
    workbook = xlrd.open_workbook(data_table)
    sheet = workbook.sheet_by_index(0)

    all_records = []
    for index in range(1, sheet.nrows):
        line = sheet.row_values(index)
        # Select and reorder the dataset-specific columns used downstream
        tmp = [line[0], line[1], line[4], line[7], line[5], line[8], line[6], line[-2], line[9], line[2]]
        all_records.append(tmp)

    return all_records


def load_data():
    """Load the data to predict, together with the id-mapping file."""
    if plat == "windows":
        mapping_table, data_table = "./data/id_map.xlsx", "./data/local_sample.xlsx"
        art_to_id, lang_to_id, release_to_id, duration_mean, duration_std = load_mapping_local(mapping_table)
        org_data = load_train_local(data_table)
    else:
        # Only the local (Windows) branch is shown in this post; reading from
        # FLAGS.tables is omitted
        raise NotImplementedError("non-local data loading is not shown here")

    # Build the feature matrix for prediction
    predict_data = []
    for line in org_data:
        artist_id, language, release, duration, embedding = line[2], line[8], line[7], line[6], line[9]
        art_fea = art_to_id.get(artist_id, 0)        # unseen artist -> 0
        language_fea = lang_to_id.get(language, 0)   # unseen language -> 0
        duration_fea = (float(duration) - duration_mean) / duration_std  # z-score

        # Bucket the release year into 5-year bins; invalid or out-of-range
        # years fall into the special buckets "0000" and "1111"
        try:
            release_year = release.split("-")[0]
            if release_year == "0000" or int(release_year) > 2022:
                release_id = "0000"
            elif int(release_year) < 1970:
                release_id = "1111"
            else:
                release_id = str(int(release_year) - int(release_year) % 5)
        except Exception:
            release_id = "0000"

        release_fea = release_to_id.get(release_id, 0)
        tmp = embedding.split(",") + [art_fea, language_fea, release_fea, duration_fea]
        predict_data.append(tmp)

    predict_data = np.array(predict_data).astype(np.float32)

    return org_data, predict_data


def main_func():
    org_data, predict_data = load_data()
    print("data loaded")

    if plat == "windows":
        meta_path = "./model/{}_mll.meta".format(domain)
        ckpt_path = "./model/{}_mll".format(domain)
    else:
        meta_path = os.path.join(FLAGS.checkpointDir, '{}_mll.meta'.format(domain))
        ckpt_path = os.path.join(FLAGS.checkpointDir, '{}_mll'.format(domain))

    # Restore the graph definition and look up the input placeholders by name
    saver = tf.train.import_meta_graph(meta_path)
    graph = tf.get_default_graph()

    batch_samples = graph.get_operation_by_name('batch_samples').outputs[0]
    music_embedding = graph.get_operation_by_name('music_embedding').outputs[0]
    art_id = graph.get_operation_by_name('art_id').outputs[0]
    lang_id = graph.get_operation_by_name('lang_id').outputs[0]
    release_id = graph.get_operation_by_name('release_id').outputs[0]
    # The output op was stored in a collection at training time
    logits = tf.get_collection('pred_network')[0]

    batchsize = 256
    pred_steps = math.ceil(len(predict_data) / batchsize)

    init_op = tf.global_variables_initializer()
    prediction_y = []
    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, ckpt_path)

        for step in range(1, pred_steps + 1):
            start = (step - 1) * batchsize
            end = step * batchsize
            batch_x = predict_data[start:end, :]
            # Split the feature matrix back into the model's inputs:
            # embedding, artist id, language id, release id, duration
            music_x, art_x, lang_x, release_x, dur_x = batch_x[:, :-4], batch_x[:, -4], batch_x[:, -3], \
                                                       batch_x[:, -2], batch_x[:, -1].reshape(-1, 1)
            music_x = np.hstack((music_x, dur_x))  # duration joins the dense part
            art_x = art_x.reshape([len(art_x), 1]).astype(int)
            lang_x = lang_x.reshape([len(lang_x), 1]).astype(int)
            release_x = release_x.reshape([len(release_x), 1]).astype(int)

            batch_prediction = sess.run([logits], feed_dict={music_embedding: music_x,
                                                             batch_samples: [len(art_x)],
                                                             art_id: art_x,
                                                             lang_id: lang_x,
                                                             release_id: release_x})
            prediction_y += batch_prediction[0].tolist()

            if step % 100 == 0:
                print("predict: ", step * batchsize)

        # Apply sigmoid and threshold at 0.5; .eval() works because we are
        # still inside the session's context
        prediction_y = tf.sigmoid(np.array(prediction_y)).eval()
        pred_y = tf.cast(tf.greater(prediction_y, 0.5), tf.float32).eval()

    # Collect (sample id, predicted labels); samples with no positive label
    # are skipped. Writing write_recs to FLAGS.outputs is not shown in the post.
    write_recs = []
    for i, pred in enumerate(pred_y.tolist()):
        pred_tag = [save_label[index] for index, pred_flag in enumerate(pred) if pred_flag > 0]
        if len(pred_tag) == 0:
            continue
        write_recs.append([org_data[i][0], ",".join(pred_tag)])
    return write_recs


if __name__ == '__main__':
    main_func()
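
The script stops after building write_recs; how the records reach the FLAGS.outputs table is not shown in the post. As a stand-in, below is a minimal sketch that writes the records to a local CSV file (the file name is a placeholder, not the original output path), together with a NumPy equivalent of the sigmoid-and-threshold step that avoids building extra TensorFlow ops inside the session.

import csv
import numpy as np

def threshold_logits(logits, thresh=0.5):
    """NumPy equivalent of the tf.sigmoid / tf.greater post-processing."""
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float32)))
    return (probs > thresh).astype(np.float32)

def write_results_csv(write_recs, path="predictions.csv"):
    """Dump (sample id, comma-joined labels) records to a local CSV file."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "labels"])
        writer.writerows(write_recs)

# Example: two samples, three labels
print(threshold_logits([[2.0, -1.0, 0.3], [-3.0, 0.1, 1.5]]))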
