基于 MXNet 框架的一本中文书很适合用来学习机器学习，这里尝试用 TensorFlow 复现书中的模型。
已经有大佬在做这件事了，GitHub 地址在这里。
我这里使用的环境是 macOS 10.15.5、Python 3.7、TensorFlow 2.0.0。
与原书代码的一个不同之处在于损失函数：这里使用 BinaryCrossentropy（因此模型只输出 1 个 logit），而原书使用的是 SoftmaxCrossEntropyLoss。
完整代码：
# coding: utf-8
import tensorflow as tf
from tensorflow import keras
class BiRNNModel(keras.Model):
    """Bidirectional-LSTM text classifier: embedding -> BiLSTM stack -> logits.

    The embedding layer reads the module-level ``vocab_size``,
    ``embedding_size`` and ``max_length`` settings, so those must be defined
    before the model is instantiated.

    Args:
        units: Hidden size of each direction of every LSTM layer.
        num_classes: Number of target classes. With 2 classes the model emits
            a single logit (to pair with ``BinaryCrossentropy(from_logits=True)``);
            otherwise it emits ``num_classes`` logits.
        num_layers: Number of stacked bidirectional LSTM layers (>= 1).

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, units, num_classes=2, num_layers=1):
        super(BiRNNModel, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers!r}")
        self.units = units
        self.num_classes = num_classes
        self.num_layers = num_layers
        # mask_zero=True: id 0 is the value the script pads with, so the
        # LSTMs skip padded timesteps instead of consuming long runs of zeros
        # (with post-padding the backward direction would otherwise read
        # only padding before reaching the real tokens).
        self.embedding = keras.layers.Embedding(
            vocab_size, embedding_size, input_length=max_length, mask_zero=True)
        # Every layer except the last returns the full sequence so the next
        # BiLSTM has a sequence to consume; the last returns only its final
        # (forward/backward concatenated) output.
        self.lstm_layers = [
            keras.layers.Bidirectional(
                keras.layers.LSTM(self.units,
                                  return_sequences=(i < num_layers - 1)))
            for i in range(num_layers)
        ]
        # No activation: the model outputs raw logits.
        self.dense = keras.layers.Dense(1 if num_classes == 2 else num_classes)

    def call(self, x, training=None, mask=None):
        """Map a batch of padded token-id sequences to unnormalised logits.

        The ``mask`` argument is accepted for Keras API compatibility but is
        not used; the padding mask is derived from the token ids themselves.
        """
        seq_mask = self.embedding.compute_mask(x)
        h = self.embedding(x)
        for lstm_layer in self.lstm_layers:
            h = lstm_layer(h, mask=seq_mask, training=training)
        return self.dense(h)
# ---- Hyper-parameters -----------------------------------------------------
train_times = 3  # NOTE(review): not referenced anywhere in this file
max_length = 500  # reviews are padded/truncated to this many ids below
embedding_size = 100  # word-vector size used by the model's Embedding layer
embedding_dim = 16  # NOTE(review): unused; the model reads embedding_size
batch_size = 128
vocab_size = 40000  # vocabulary size (number of distinct words kept)
units, num_classes = 100, 2
epochs = 10
# Offset added to every word rank so the lowest ids stay free for the
# special tokens defined below.
index_from = 3

# Download (or load from the Keras cache) the IMDB dataset.
imdb = keras.datasets.imdb
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(
    num_words=vocab_size, index_from=index_from)

# Load the word -> rank vocabulary and shift it by the same offset passed to
# load_data, so the ids line up with train_data/test_data.
word_index = imdb.get_word_index()
word_index = {word: rank + index_from for word, rank in word_index.items()}
# After shifting by index_from, the low ids are free for special tokens.
word_index['<PAD>'] = 0
word_index['<START>'] = 1
word_index['<UNK>'] = 2
word_index['<END>'] = 3
# id -> word, used to turn id sequences back into text.
reverse_word_index = {token_id: word for word, token_id in word_index.items()}
# Decode word ids back into text.
def decode_review(text_ids):
    """Render a sequence of word ids as a space-separated string.

    Args:
        text_ids: Iterable of integer word ids (e.g. one row of train_data).

    Returns:
        The decoded words joined by single spaces; any id absent from
        ``reverse_word_index`` is rendered as ``"<UNK>"``.
    """
    lookup = reverse_word_index.get
    words = (lookup(token_id, "<UNK>") for token_id in text_ids)
    return " ".join(words)
# Example: decode_review(train_data[0]) shows the first training review as text.

# Pad (or truncate) every review to exactly max_length ids.
train_data = keras.preprocessing.sequence.pad_sequences(
    train_data,
    # Fill value: 0 is the <PAD> id.
    value=0,
    # 'post' puts the padding after the review; 'pre' would put it in front.
    padding='post',
    maxlen=max_length)
test_data = keras.preprocessing.sequence.pad_sequences(
    test_data, value=0, padding='post', maxlen=max_length)

model = BiRNNModel(units, num_classes, num_layers=2)
# BiRNNModel ends in a Dense(1) layer with no activation, i.e. it outputs a
# raw logit. The loss is told so via from_logits=True, and accuracy must
# therefore threshold the logit at 0 (equivalent to sigmoid(logit) > 0.5).
# The plain string 'accuracy' would resolve to binary accuracy with its
# default 0.5 threshold applied directly to the logit, misreporting accuracy.
model.compile(optimizer=keras.optimizers.Adam(0.001),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy(name='accuracy',
                                                    threshold=0.0)])
# NOTE(review): the test split doubles as validation data here, so the
# reported val_accuracy is not an untouched held-out estimate.
model.fit(train_data, train_labels,
          epochs=epochs, batch_size=batch_size,
          validation_data=(test_data, test_labels))
model.summary()