1、过程介绍
(1)读取需要做inference的数据
(2)将每个样本的特征做处理,处理成模型训练时使用的形式,这里涉及到id型特征的id映射。
(3)使用模型做标签预测。
2、代码如下
import os
import platform
import math
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.python.ops import array_ops
# Platform name ("windows", "linux", ...); selects local xlsx paths vs. FLAGS paths below.
plat = platform.system().lower()
# TF1-style command-line flags, used when not running locally on Windows.
tf.app.flags.DEFINE_string("tables", "", "Tables info")
tf.app.flags.DEFINE_string("outputs", "", "predict result output table")
tf.app.flags.DEFINE_string("checkpointDir", '', "Model checkpoint dir")
FLAGS = tf.app.flags.FLAGS
# Output tag names, index-aligned with the columns of the model's logits.
save_label = ["label1", "label2", "label3"]
# Model-name prefix for checkpoint files ("{domain}_mll"); empty here — TODO(review): confirm.
domain = ""
def load_mapping_local(mapping_table):
    """Load the feature-id mapping table from a local Excel workbook.

    Rows are tagged in column 0 as "artist" / "language" / "release" id
    mappings; any other tag carries the duration statistics, where column 1
    distinguishes "mean" from the standard deviation.

    Returns:
        (art_to_id, lang_to_id, release_to_id, duration_mean, duration_std)
    """
    import xlrd
    sheet = xlrd.open_workbook(mapping_table).sheet_by_index(0)
    art_to_id, lang_to_id, release_to_id = {}, {}, {}
    duration_mean, duration_std = 0, 0
    # Skip the header row, then dispatch on the tag in column 0.
    for row_idx in range(1, sheet.nrows):
        row = sheet.row_values(row_idx)
        tag, key, value = row[0], row[1], row[2]
        if tag == "artist":
            art_to_id[key] = int(value)
        elif tag == "language":
            lang_to_id[key] = int(value)
        elif tag == "release":
            release_to_id[key] = int(value)
        elif key == "mean":
            duration_mean = float(value)
        else:
            duration_std = float(value)
    return art_to_id, lang_to_id, release_to_id, duration_mean, duration_std
def load_train_local(data_table):
    """Read the local sample workbook and reorder each row's columns into
    the record layout the rest of the pipeline expects.

    Returns a list of 10-element records (header row skipped).
    """
    import xlrd
    sheet = xlrd.open_workbook(data_table).sheet_by_index(0)
    # Fixed column selection/order — presumably mirrors the training layout;
    # downstream code reads indices 2 (artist), 6-9 (duration/release/lang/embedding).
    picked = (0, 1, 4, 7, 5, 8, 6, -2, 9, 2)
    all_records = []
    for row_idx in range(1, sheet.nrows):
        row = sheet.row_values(row_idx)
        all_records.append([row[col] for col in picked])
    return all_records
def load_data():
    """Read the samples to predict and the id-mapping file, and build the
    model's input matrix.

    Returns:
        org_data: raw records as read from the local sample sheet.
        predict_data: float32 ndarray, one row per sample, laid out as
            [embedding..., art_id, lang_id, release_id, normalized_duration].

    Raises:
        RuntimeError: when run on a platform with no configured data paths.
    """
    if plat == "windows":
        mapping_table, data_table = "./data/id_map.xlsx", "./data/local_sample.xlsx"
    else:
        # BUG FIX: the original only defined the paths on Windows and crashed
        # with a NameError on any other platform; fail fast with a clear message.
        raise RuntimeError("load_data: local xlsx paths are only configured for Windows")
    art_to_id, lang_to_id, release_to_id, duration_mean, duration_std = load_mapping_local(mapping_table)
    org_data = load_train_local(data_table)
    # Build the per-sample feature rows.
    predict_data = []
    for line in org_data:
        artisid, language, release, duration, embedding = line[2], line[8], line[7], line[6], line[9]
        art_fea = art_to_id.get(artisid, 0)  # unknown artists fall back to id 0
        # NOTE(review): unknown languages raise KeyError (unlike artist) — confirm intended.
        language_fea = lang_to_id[language]
        duration_fea = (float(duration) - duration_mean) / duration_std
        # Bucket the release year into 5-year bins; "0000" = unknown/future, "1111" = pre-1970.
        try:
            release_year = release.split("-")[0]
            if release_year == "0000" or int(release_year) > 2022:
                release_id = "0000"
            elif int(release_year) < 1970:
                release_id = "1111"
            else:
                release_id = str(int(release_year) - int(release_year) % 5)
        except (AttributeError, ValueError, IndexError):
            # BUG FIX: was a bare `except:` (also swallowed KeyboardInterrupt etc.);
            # malformed or non-string release dates map to the "unknown" bucket.
            release_id = "0000"
        release_fea = release_to_id[release_id]
        predict_data.append(embedding.split(",") + [art_fea, language_fea, release_fea, duration_fea])
    predict_data = np.array(predict_data).astype(np.float32)
    return org_data, predict_data
def main_func():
    """Restore the trained checkpoint and predict multi-label tags for every sample.

    Restores the TF1 graph saved as "{domain}_mll", feeds the feature matrix
    in batches of 256, sigmoids the logits and thresholds at 0.5.

    Returns:
        List of [record_id, "tag1,tag2,..."] rows for samples with at least
        one positive label. (The original computed this list but discarded it.)
    """
    org_data, predict_data = load_data()
    print("load data……")
    if plat == "windows":
        meta_path = "./model/{}_mll.meta".format(domain)
        ckpt_path = "./model/{}_mll".format(domain)
    else:
        meta_path = os.path.join(FLAGS.checkpointDir, '{}_mll.meta'.format(domain))
        ckpt_path = os.path.join(FLAGS.checkpointDir, '{}_mll'.format(domain))
    saver = tf.train.import_meta_graph(meta_path)
    graph = tf.get_default_graph()
    # Input placeholders / output tensor saved with the training graph.
    batch_samples = graph.get_operation_by_name('batch_samples').outputs[0]
    music_embedding = graph.get_operation_by_name('music_embedding').outputs[0]
    art_id = graph.get_operation_by_name('art_id').outputs[0]
    lang_id = graph.get_operation_by_name('lang_id').outputs[0]
    release_id = graph.get_operation_by_name('release_id').outputs[0]
    logits = tf.get_collection('pred_network')[0]
    batchsize = 256
    pred_steps = math.ceil(len(predict_data) / batchsize)
    prediction_y = []
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt_path)  # restore overwrites the just-initialized variables
        for step in range(1, pred_steps + 1):
            start = (step - 1) * batchsize
            batch_x = predict_data[start:start + batchsize, :]
            # Column layout from load_data: [embedding..., art, lang, release, duration];
            # the duration column is appended back onto the embedding input.
            music_x = np.hstack((batch_x[:, :-4], batch_x[:, -1].reshape(-1, 1)))
            art_x = batch_x[:, -4].reshape(-1, 1).astype(int)
            lang_x = batch_x[:, -3].reshape(-1, 1).astype(int)
            release_x = batch_x[:, -2].reshape(-1, 1).astype(int)
            batch_prediction = sess.run([logits], feed_dict={music_embedding: music_x,
                                                             batch_samples: [len(art_x)],
                                                             art_id: art_x,
                                                             lang_id: lang_x,
                                                             release_id: release_x})
            prediction_y += batch_prediction[0].tolist()
            if step % 100 == 0:
                print("predict: ", step * batchsize)
    # BUG FIX: the original called tf.sigmoid(...).eval(), which requires a
    # default session and appears to run after the Session block has closed.
    # A pure-numpy sigmoid/threshold is equivalent and session-independent.
    prob_y = 1.0 / (1.0 + np.exp(-np.array(prediction_y)))
    pred_y = (prob_y > 0.5).astype(np.float32)
    write_recs = []
    for i, pred in enumerate(pred_y.tolist()):
        pred_tag = [save_label[idx] for idx, flag in enumerate(pred) if flag > 0]
        if pred_tag:
            write_recs.append([org_data[i][0], ",".join(pred_tag)])
    # TODO(review): FLAGS.outputs is declared but never written to — confirm
    # where these records should be persisted.
    return write_recs
# Script entry point: run inference when executed directly.
if __name__ == '__main__':
    main_func()