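General workflow for a word embedding model
- Read the file
- Split the data into training and test sets
- Build the vocabulary
- Build the model
- Train the model with fit
- Plot the loss (accuracy) curves
When training the model, pay attention to the type and shape of the inputs; for example, the label lists are converted with np.array before being passed to fit.
Ready-made (pretrained) embedding matrices can usually be downloaded for word embedding models.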
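As a rough illustration of the last point, the sketch below builds an embedding matrix from pretrained vectors stored in the common plain-text layout where each line is a word followed by its vector components (the layout GloVe files use). The sample vectors, the `word_index` mapping, and the helper names `parse_vectors` and `build_embedding_matrix` are all made up for illustration; real files have far more rows and dimensions.

```python
# Hypothetical sample in "word v1 v2 ..." text format (values are invented).
PRETRAINED_TEXT = """\
the 0.1 0.2 0.3
cat 0.4 0.5 0.6
sat 0.7 0.8 0.9
"""

def parse_vectors(text):
    """Map each word to its vector, stored as a list of floats."""
    vectors = {}
    for line in text.splitlines():
        parts = line.split()
        if not parts:
            continue
        vectors[parts[0]] = [float(x) for x in parts[1:]]
    return vectors

def build_embedding_matrix(word_index, vectors, vocab_size, dim):
    """Row i holds the vector of the word with index i; unknown words stay all-zero."""
    matrix = [[0.0] * dim for _ in range(vocab_size)]
    for word, idx in word_index.items():
        if idx < vocab_size and word in vectors:
            matrix[idx] = vectors[word]
    return matrix

vectors = parse_vectors(PRETRAINED_TEXT)
# Hypothetical mapping in the style of Tokenizer.word_index (index 0 is reserved for padding).
word_index = {'<OOV>': 1, 'the': 2, 'cat': 3, 'dog': 4}
embedding_matrix = build_embedding_matrix(word_index, vectors, vocab_size=5, dim=3)
for row in embedding_matrix:
    print(row)
```

A matrix like this could then initialise the `Embedding` layer, for example through `embeddings_initializer=tf.keras.initializers.Constant(embedding_matrix)` together with `trainable=False` if the pretrained vectors should stay fixed.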
import json
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
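Hyperparameters
vocab_size = 1000        # vocabulary size passed to Tokenizer and Embedding
embedding_dim = 32       # dimension of each word vector
max_length = 16          # every sequence is padded or truncated to this length
trunc_type = 'post'      # truncate at the end of the sequence
padding_type = 'post'    # pad at the end of the sequence
cov_tok = '<OOV>'        # placeholder token for out-of-vocabulary words
training_size = 20000    # first 20000 samples train the model; the rest validate it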
Read the JSON file
with open('/root/jupyter/AI/TensorFlow2/NLP/sarcasm.json', 'r') as f:
    datastore = json.load(f)
sentences = []
labels = []
# Collect the headline text and its sarcasm label from every record
for item in datastore:
    sentences.append(item['headline'])
    labels.append(item['is_sarcastic'])
print(len(sentences))
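The loading loop assumes the file is a JSON array of objects that each carry a `headline` string and an `is_sarcastic` flag. A minimal sketch of that assumed layout, using two invented records held in memory instead of the real file:

```python
import json

# Two invented records in the layout the loading code expects.
SAMPLE_JSON = '''
[
  {"headline": "local man discovers weekend", "is_sarcastic": 1},
  {"headline": "city council approves new budget", "is_sarcastic": 0}
]
'''

datastore = json.loads(SAMPLE_JSON)

sentences = []
labels = []
for item in datastore:
    sentences.append(item['headline'])
    labels.append(item['is_sarcastic'])

print(len(sentences))
print(labels)
```

If a record lacked either key, the loop would raise a `KeyError`, so checking one record by hand before loading the whole file can save time.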
Split into training and validation sets
training_sentences = sentences[0:training_size]
testing_sentences = sentences[training_size:]
training_labels = labels[0:training_size]
testing_labels = labels[training_size:]
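The slices take the first `training_size` items for training and everything after them for validation, in file order. The toy sketch below shows that behaviour and, as an optional variation that is not part of the original code, shuffles sentence/label pairs together before splitting. The helper name `split_by_size` and the sample data are invented.

```python
import random

def split_by_size(items, training_size):
    """First training_size items for training, the remainder for validation."""
    return items[:training_size], items[training_size:]

sentences = ['s0', 's1', 's2', 's3', 's4']
labels = [0, 1, 0, 1, 1]

train_s, val_s = split_by_size(sentences, 3)
train_l, val_l = split_by_size(labels, 3)
print(train_s, val_s)
print(train_l, val_l)

# If training_size exceeds the data size, the validation part is simply empty.
print(split_by_size(sentences, 10))

# Optional variation: shuffle pairs together so each label stays with its sentence.
pairs = list(zip(sentences, labels))
random.Random(42).shuffle(pairs)
shuffled_s = [s for s, _ in pairs]
shuffled_l = [l for _, l in pairs]
```

Shuffling matters mainly when the file is ordered in some way, for instance by source or date; whether that applies to this dataset is not stated in the original.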
Build the vocabulary
tokenizer = Tokenizer(num_words=vocab_size, oov_token=cov_tok)
# Fit on the training sentences only, so validation words do not enter the vocabulary
tokenizer.fit_on_texts(training_sentences)
word_index = tokenizer.word_index
# Convert sentences to integer sequences, then pad or truncate them to max_length
training_sequences = tokenizer.texts_to_sequences(training_sentences)
training_padded = pad_sequences(training_sequences, maxlen=max_length,
                                padding=padding_type, truncating=trunc_type)
testing_sequences = tokenizer.texts_to_sequences(testing_sentences)
testing_padded = pad_sequences(testing_sequences, maxlen=max_length,
                               padding=padding_type, truncating=trunc_type)
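To make the effect of `num_words`, `oov_token`, and the two `'post'` settings concrete, here is a simplified pure-Python sketch of the same steps. It is not the Keras implementation: it only lowercases and splits on whitespace (no punctuation filtering), and the helper names are invented.

```python
from collections import Counter

def build_word_index(texts, oov_token='<OOV>'):
    """OOV token gets index 1; other words are numbered by descending frequency."""
    counts = Counter(word for text in texts for word in text.lower().split())
    word_index = {oov_token: 1}
    for word, _ in counts.most_common():
        word_index[word] = len(word_index) + 1
    return word_index

def texts_to_sequences(texts, word_index, num_words, oov_token='<OOV>'):
    """Words that are unknown, or whose index is >= num_words, map to the OOV index."""
    oov = word_index[oov_token]
    sequences = []
    for text in texts:
        seq = []
        for word in text.lower().split():
            idx = word_index.get(word)
            seq.append(idx if idx is not None and idx < num_words else oov)
        sequences.append(seq)
    return sequences

def pad_post(sequences, maxlen):
    """Truncate at the end, then pad with zeros at the end."""
    padded = []
    for seq in sequences:
        kept = seq[:maxlen]
        padded.append(kept + [0] * (maxlen - len(kept)))
    return padded

train = ['the cat sat', 'the cat ran away today']
word_index = build_word_index(train)
sequences = texts_to_sequences(train + ['the dog sat'], word_index, num_words=100)
padded = pad_post(sequences, maxlen=4)
print(word_index)
print(padded)
```

Index 0 never appears in `word_index`, which is why it can serve as the padding value; the unseen word "dog" is replaced by the OOV index rather than dropped.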
Build the model
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, embedding_dim, input_length=max_length),
    tf.keras.layers.GlobalAveragePooling1D(),  # average the word vectors over the sequence
    tf.keras.layers.Dense(24, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')  # probability that the headline is sarcastic
])
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
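As a sanity check on what `model.summary()` reports, the parameter counts can be worked out by hand from the hyperparameters. The sketch assumes the standard counts: an embedding table holds `vocab_size * embedding_dim` weights, a dense layer holds `inputs * units + units` (weights plus biases), and average pooling has no parameters. The helper `dense_params` is invented for illustration.

```python
vocab_size = 1000
embedding_dim = 32
max_length = 16

def dense_params(inputs, units):
    """Weights (inputs * units) plus one bias per unit."""
    return inputs * units + units

# Shapes per layer, with `batch` standing for the batch size:
#   input                  (batch, max_length)
#   Embedding              (batch, max_length, embedding_dim)
#   GlobalAveragePooling1D (batch, embedding_dim)
#   Dense(24)              (batch, 24)
#   Dense(1)               (batch, 1)
embedding_params = vocab_size * embedding_dim
pooling_params = 0
hidden_params = dense_params(embedding_dim, 24)
output_params = dense_params(24, 1)
total_params = embedding_params + pooling_params + hidden_params + output_params

print(embedding_params, hidden_params, output_params, total_params)
```

Almost all of the parameters sit in the embedding table, which is one reason reusing a pretrained matrix can be attractive.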
Fit the model
import numpy as np
# The labels are plain Python lists; compare them with their NumPy form before passing the arrays to fit
print(type(testing_labels), len(testing_labels), testing_labels, '\n')
print(type(np.array(testing_labels)), len(np.array(testing_labels)), np.array(testing_labels), '\n')
num_epochs = 30
history = model.fit(training_padded, np.array(training_labels), epochs=num_epochs,
                    validation_data=(testing_padded, np.array(testing_labels)), verbose=2)
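Each epoch line printed by `fit` reports `loss` and `accuracy` along with their `val_` counterparts. As a rough sketch of what those numbers mean for a sigmoid output, the following computes binary cross-entropy and accuracy by hand on invented predictions. The clipping constant `eps` and the 0.5 decision threshold are assumptions chosen to resemble common defaults, not values taken from the original.

```python
import math

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean of -(t*log(p) + (1-t)*log(1-p)), with p clipped away from 0 and 1."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
    return total / len(y_true)

def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of predictions that land on the correct side of the threshold."""
    correct = sum(1 for t, p in zip(y_true, y_pred) if (p > threshold) == bool(t))
    return correct / len(y_true)

# Invented labels and sigmoid outputs for four samples.
y_true = [1, 0, 1, 0]
y_pred = [0.9, 0.2, 0.4, 0.6]

loss = binary_crossentropy(y_true, y_pred)
accuracy = binary_accuracy(y_true, y_pred)
print(loss, accuracy)
```

The loss penalises confident wrong answers more heavily than the accuracy does, so the two curves need not move together from epoch to epoch.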
Loss and accuracy plots
```python
import matplotlib.pyplot as plt

def plot_graphs(history, string):
    # Plot a training metric and its validation counterpart against the epoch
    plt.plot(history.history[string])
    plt.plot(history.history['val_' + string])
    plt.xlabel("Epochs")
    plt.ylabel(string)
    plt.legend([string, 'val_' + string])
    plt.show()

plot_graphs(history, 'accuracy')
plot_graphs(history, 'loss')
```
Reposted from:
https://www.bilibili.com/read/cv6490104/