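I came across this project on GitHub: https://github.com/dennybritz/cnn-text-classification-tf
It performs text sentiment analysis with TensorFlow. The code appears to follow older conventions, so a few problems came up, which I record here:
1. train.py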
FLAGS.flag_values_dict()
# After the TensorFlow upgrade, FLAGS._parse_flags() is no longer available; use FLAGS.flag_values_dict() instead
#FLAGS._parse_flags()
2. data_helpers.py
lines = [clean_str(seperate_line(line)) for line in lines]
#lines = [clean_str(seperate_line(line.decode('utf-8'))) for line in lines]
# In Python 3, str has no decode() method, so the decode('utf-8') call is removed
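The underlying reason is that in Python 3 only bytes objects have decode(); text read from a file opened in text mode is already str. A minimal sketch of a helper that accepts either type is shown below. The name to_text is hypothetical and not part of the project.

```python
def to_text(line, encoding="utf-8"):
    """Return line as str, decoding only when it is a bytes object."""
    if isinstance(line, bytes):
        return line.decode(encoding)
    return line


# Lines read in binary mode are bytes; lines read in text mode are str.
raw_lines = ["好评".encode("utf-8"), "差评"]
lines = [to_text(line) for line in raw_lines]
```

With a helper like this, the same list comprehension works whether the file was opened in binary or text mode.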
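3. eval.py
Several problems came up here. The first is the warning: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated
Online sources say this is caused by the NumPy version, but changing the version made no difference for me. What worked was opening the dtypes.py file named in the warning and modifying the flagged code:
Original code:
Modified code: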
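The second problem is that the CSV file written by the original code contained hex-escaped byte strings, which were hard to work with. After some trial and error I fixed it as follows: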
#! /usr/bin/env python
import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
import word2vec_helpers
from text_cnn import TextCNN
import csv
import unicodecsv as uccv
# Parameters
# ==================================================
# Data Parameters
tf.flags.DEFINE_string("input_text_file", "./data/spam_100.utf8", "Test text data source to evaluate.")
tf.flags.DEFINE_string("input_label_file", "", "Label file for test text data source.")
# Eval Parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")
tf.flags.DEFINE_boolean("eval_train", True, "Evaluate on all training data")
# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
FLAGS = tf.flags.FLAGS
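#FLAGS._parse_flags()
# Newer TensorFlow versions no longer provide the call above
FLAGS.flag_values_dict()
print("\nParameters:")
# flag_values_dict() maps each flag name to its value
for attr, value in sorted(FLAGS.flag_values_dict().items()):
    print("{}={}".format(attr.upper(), value))
print("")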
# validate
# ==================================================
# validate checkpoint file
FLAGS.checkpoint_dir = "./runs/1589886659/checkpoints"
checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if checkpoint_file is None:
    print("Cannot find a valid checkpoint file!")
    exit(0)
print("Using checkpoint file : {}".format(checkpoint_file))
# validate word2vec model file
trained_word2vec_model_file = os.path.join(FLAGS.checkpoint_dir, "..", "trained_word2vec.model")
if not os.path.exists(trained_word2vec_model_file):
    print("Word2vec model file \'{}\' doesn't exist!".format(trained_word2vec_model_file))
print("Using word2vec model file : {}".format(trained_word2vec_model_file))
# validate training params file
training_params_file = os.path.join(FLAGS.checkpoint_dir, "..", "training_params.pickle")
if not os.path.exists(training_params_file):
    print("Training params file \'{}\' is missing!".format(training_params_file))
print("Using training params file : {}".format(training_params_file))
# Load params
params = data_helpers.loadDict(training_params_file)
num_labels = int(params['num_labels'])
max_document_length = int(params['max_document_length'])
# Load data
if FLAGS.eval_train:
    x_raw, y_test = data_helpers.load_data_and_labels(FLAGS.input_text_file, FLAGS.input_label_file, num_labels)
else:
    x_raw = ["a masterpiece four years in the making", "everything is off."]
    y_test = [1, 0]
# Get Embedding vector x_test
sentences, max_document_length = data_helpers.padding_sentences(x_raw, '<PADDING>', padding_sentence_length = max_document_length)
x_test = np.array(word2vec_helpers.embedding_sentences(sentences, file_to_load = trained_word2vec_model_file))
print("x_test.shape = {}".format(x_test.shape))
# Evaluation
# ==================================================
print("\nEvaluating...\n")
checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
graph = tf.Graph()
with graph.as_default():
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # Load the saved meta graph and restore variables
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)
        # Get the placeholders from the graph by name
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        # input_y = graph.get_operation_by_name("input_y").outputs[0]
        dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
        # Tensors we want to evaluate
        predictions = graph.get_operation_by_name("output/predictions").outputs[0]
        # Generate batches for one epoch
        batches = data_helpers.batch_iter(list(x_test), FLAGS.batch_size, 1, shuffle=False)
        # Collect the predictions here
        all_predictions = []
        for x_test_batch in batches:
            batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})
            all_predictions = np.concatenate([all_predictions, batch_predictions])
# Print accuracy if y_test is defined
if y_test is not None:
    correct_predictions = float(sum(all_predictions == y_test))
    print("Total number of test examples: {}".format(len(y_test)))
    print("Accuracy: {:g}".format(correct_predictions/float(len(y_test))))
# Save the evaluation to a csv
predictions_human_readable = np.column_stack((np.array([text for text in x_raw]), all_predictions))
out_path = os.path.join(FLAGS.checkpoint_dir, "..", "prediction.csv")
print("Saving evaluation to {0}".format(out_path))
with open(out_path, 'wb+') as f:
    #csv.writer(f).writerows(predictions_human_readable)
    w = uccv.writer(f, encoding='gbk')
    w.writerows(predictions_human_readable)
The change is that unicodecsv replaces the csv library: the file is opened in binary mode ('wb+') and the writer's encoding is set to 'gbk'.
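On Python 3, the standard csv module can likely produce the same result without unicodecsv, as long as the file is opened in text mode with an explicit encoding. The sketch below is an alternative, not the code used above. write_predictions is a hypothetical helper, and the decoding step assumes any byte strings in the rows are UTF-8.

```python
import csv


def write_predictions(out_path, rows, encoding="gbk"):
    """Write rows to a CSV file, decoding any bytes cells to str first."""
    # newline="" lets the csv module control line endings itself.
    with open(out_path, "w", newline="", encoding=encoding) as f:
        writer = csv.writer(f)
        for row in rows:
            writer.writerow(
                [cell.decode("utf-8") if isinstance(cell, bytes) else cell
                 for cell in row]
            )
```

It could be called as write_predictions(out_path, predictions_human_readable). The 'gbk' default follows the encoding chosen above; 'utf-8' would work the same way if the file is not meant for a GBK-based viewer.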
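The third thing to note: every time you run eval.py, remember to update FLAGS.checkpoint_dir. The original author ran the script from a shell and passed the directory as an argument on each call, so it never appears in the code; I added the assignment to FLAGS.checkpoint_dir myself.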
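One way to avoid editing the path by hand is to pick the newest run directory automatically. This is a minimal sketch, assuming each training run writes to ./runs/<integer timestamp>/checkpoints as in the path used above. latest_checkpoint_dir is a hypothetical helper, not part of the project.

```python
import os


def latest_checkpoint_dir(runs_root="./runs"):
    """Return the checkpoints directory of the most recent run, or None."""
    # Assumption: run directories are named by an integer timestamp.
    run_names = [
        name for name in os.listdir(runs_root)
        if name.isdigit() and os.path.isdir(os.path.join(runs_root, name))
    ]
    if not run_names:
        return None
    newest = max(run_names, key=int)
    return os.path.join(runs_root, newest, "checkpoints")
```

The hard-coded line could then become FLAGS.checkpoint_dir = latest_checkpoint_dir(). Alternatively, since checkpoint_dir is already defined as a flag, it can be passed on the command line as --checkpoint_dir=... once the hard-coded assignment is removed.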