IMDB 数据集介绍
MDB 数据集包含来自互联网电影数据库(IMDB)的 50 000 条严重两极分化的评论。数据集被分为用于训练的 25 000 条评论与用于测试的 25 000 条评论,训练集和测试集都包含 50% 的正面评论和 50% 的负面评论。
train_labels 和 test_labels 都是 0 和 1 组成的列表,其中 0代表负面(negative),1 代表正面(positive)
导入相关的库
import tensorflow.keras as keras
from tensorflow.python.keras.preprocessing import sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense ,LSTM,Embedding,Flatten
数据预处理
keras自带了IMDB的数据集,可直接使用load_data进行加载。
(x_train, y_train), (x_test, y_test) = imdb.load_data(path=“imdb.npz”,
num_words=None,
skip_top=0,
maxlen=None,
seed=113,
start_char=1,
oov_char=2,
index_from=3)
返回:
2 个元组:
x_train, x_test: 序列的列表,即词索引的列表。如果指定了 num_words 参数,则可能的最大索引值是 num_words-1。如果指定了 maxlen 参数,则可能的最大序列长度为 maxlen。
y_train, y_test: 整数标签列表 (1 或 0)。
参数:
path: 如果你本地没有该数据集 (在 '~/.keras/datasets/' + path),它将被下载到此目录。
num_words: 整数或 None。要考虑的最常用的词语。任何不太频繁的词将在序列数据中显示为 oov_char 值。
skip_top: 整数。要忽略的最常见的单词(它们将在序列数据中显示为 oov_char 值)。
maxlen: 整数。最大序列长度。 任何更长的序列都将被截断。
seed: 整数。用于可重现数据混洗的种子。
start_char: 整数。序列的开始将用这个字符标记。设置为 1,因为 0 通常作为填充字符。
oov_char: 整数。由于 num_words 或 skip_top 限制而被删除的单词将被替换为此字符。
index_from: 整数。使用此数以上更高的索引值实际词汇索引的开始
(x_train, y_train), (x_test, y_test) =keras.datasets.imdb.load_data(
num_words=10000,
skip_top=0,
start_char=1,
oov_char=2,
index_from=3)
# 截长补短
x_train_s=sequence.pad_sequences(x_train,maxlen=128)
x_test_s=sequence.pad_sequences(x_test,maxlen=128)
模型构建
vocabulary =10000
embedding_dim =32
word_num=128
state_dim =128
model = Sequential()
model.add(Embedding(vocabulary,embedding_dim,input_length=word_num))
model.add(LSTM(state_dim, return_sequences=False))
model.add(Dense(units=1,activation='sigmoid'))
model.summary()
模型训练及预测
#损失函数,优化器,评价指标
model.compile(loss='binary_crossentropy',
optimizer="adam",
metrics=['accuracy'])
history =model.fit(x=x_train_s,y=y_train,validation_split=0.2,epochs=3,batch_size=64,verbose=1)
loss_and_metrics = model.evaluate(x_test_s, y_test, batch_size=128)
绘制训练损失和验证损失
import matplotlib.pyplot as plt
history_dict = history.history
loss_values = history_dict['loss']
val_loss_values = history_dict['val_loss']
epochs = range(1, len(loss_values) + 1)
plt.plot(epochs, loss_values, 'bo', label='Training loss')
plt.plot(epochs, val_loss_values, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
绘制训练精度和验证精度
plt.clf() # 清空图像
acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()