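# 1. Implementation code
# Build the computation graph -- LSTM model:
#   embedding
#   LSTM
#   fc
#   train_op
# Training-loop code
# Dataset wrapper:
#   api: next_batch(batch_size)
# Vocabulary wrapper:
#   api: sentence_to_id(text_sentence): convert a sentence to ids
# Category wrapper:
#   api: category2id(text_category)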
import tensorflow as tf
import os
import sys
import numpy as np
import math
# Set TensorFlow's logger verbosity to INFO.
tf.logging.set_verbosity(tf.logging.INFO)
# Progress marker: printed once the imports and logging setup have run.
print("ok1")
# 定义数据超参数
def get_default_params():
    """Return the default hyper-parameters as a ``tf.contrib.training.HParams``."""
    defaults = dict(
        num_embedding_size=16,   # length of each word embedding
        num_timesteps=50,        # LSTM step count: number of words per sentence
        num_lstm_nodes=[32, 32],
        num_lstm_layers=2,
        num_fc_nodes=32,
        batch_size=100,
        clip_lstm_grads=1.0,     # upper bound for the gradients
        learning_rate=0.001,
        num_word_threshold=10,   # word-frequency threshold
    )
    return tf.contrib.training.HParams(**defaults)
# Module-level configuration: default hyper-parameters, dataset file paths
# and the output folder for this run.
hps = get_default_params()

train_file = 'F:/channelE/lstm/text_classification_data/cnews.train.seg.txt'
val_file = 'F:/channelE/lstm/text_classification_data/cnews.val.seg.txt'
test_file = 'F:/channelE/lstm/text_classification_data/cnews.test.seg.txt'
vocab_file = 'F:/channelE/lstm/text_classification_data/cnews.vocab.txt'
category_file = 'F:/channelE/lstm/text_classification_data/cnews.category.txt'
output_folder = 'F:/channelE/lstm/run_text_rnn'

# makedirs(exist_ok=True) also creates missing parent directories and does
# not fail when the folder already exists, so no separate existence check
# (and no window between the check and the creation) is needed.
os.makedirs(output_folder, exist_ok=True)
# Progress marker: printed once the configuration above is in place.
print("ok2")
# 词表封装类
class Vocab:
    """Word-to-id vocabulary loaded from a ``word<TAB>frequency`` text file.

    Ids are assigned in file order, starting at 0, to every word whose
    frequency is at least ``num_word_threshold``. The id of the ``<UNK>``
    entry is returned for words that are not in the vocabulary; if the file
    has no ``<UNK>`` entry that passes the threshold, that id stays ``-1``.
    """

    def __init__(self, filename: str, num_word_threshold: int) -> None:
        """Load the vocabulary from *filename*.

        Args:
            filename: Path to a UTF-8 text file with one
                ``word<TAB>frequency`` entry per line.
            num_word_threshold: Minimum frequency a word needs to be kept.

        Raises:
            ValueError: If a non-blank line is not ``word<TAB>frequency``
                with an integer frequency.
        """
        self._word_to_id = {}
        self._unk = -1
        self._num_word_threshold = num_word_threshold
        self._read_dict(filename)

    def _read_dict(self, filename: str) -> None:
        """Fill the word-to-id table from the entries of *filename*."""
        with open(filename, 'r', encoding='utf-8') as f:
            # Iterate the file lazily instead of loading every line first.
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip('\r\n')
                # Blank lines (e.g. a trailing empty line) carry no entry.
                if not line.strip():
                    continue
                try:
                    word, frequency = line.split('\t')
                    frequency = int(frequency)
                except ValueError as err:
                    raise ValueError(
                        'malformed vocabulary entry at %s:%d: %r'
                        % (filename, line_no, line)
                    ) from err
                if frequency < self._num_word_threshold:
                    continue
                # Keep the first id of a repeated word. Re-assigning it would
                # hand out len(dict) again without the dict growing, so the
                # next new word would receive the very same id.
                if word in self._word_to_id:
                    continue
                idx = len(self._word_to_id)
                if word == '<UNK>':
                    self._unk = idx
                self._word_to_id[word] = idx

    def word_to_id(self, word: str) -> int:
        """Return the id of *word*, or the unknown-word id if it is absent."""
        return self._word_to_id.get(word, self._unk)

    @property
    def unk(self) -> int:
        """Id used for unknown words (``-1`` if no ``<UNK>`` entry was kept)."""
        return self._unk

    def size(self) -> int:
        """Return the number of words kept in the vocabulary."""
        return len(self._word_to_id)

    def sentence_to_id(self, sentence: str) -> list:
        """Convert a whitespace-separated sentence to a list of word ids."""
        return [self.word_to_id(cur_word) for cur_word in sentence.split()]
# 类别封装
class CategoryDict:
def __init__(self, filename):
self._category_to_id = {}
with open(filename, 'r') as f:
lines = f.readlines()
for line in lines