These notes are based on the book 《TensorFlow知识图谱实战》 by 王晓华.
I. Full Code
import numpy as np
import tensorflow as tf

labels = []
context = []
vocab = set()

with open("ChnSentiCorp.txt", mode='r', encoding='utf-8') as emotion_file:
    for line in emotion_file.readlines():
        line = line.strip().split(",")
        labels.append(int(line[0]))   # line[0] is the sentiment label
        text = line[1]                # line[1] is the sentence
        context.append(text)
        for char in text:             # collect every distinct character
            vocab.add(char)

vocab_list = list(sorted(vocab))      # sort for a deterministic char -> index mapping

token_list = []
for text in context:
    token = [vocab_list.index(char) for char in text]
    # Truncate or zero-pad every sentence to a fixed length of 80 characters
    token = token[:80] + [0] * (80 - len(token))
    token_list.append(token)

# Convert to NumPy arrays for training
token_list = np.array(token_list)
labels = np.array(labels)

input_token = tf.keras.Input(shape=(80,))
embedding = tf.keras.layers.Embedding(input_dim=3508, output_dim=128)(input_token)  # 3508 = vocabulary size of the full corpus
embedding = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(128))(embedding)
output = tf.keras.layers.Dense(2, activation=tf.nn.softmax)(embedding)
model = tf.keras.Model(input_token, output)
model.compile(optimizer='adam',
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
model.fit(token_list, labels, epochs=10, verbose=2)
II. Walkthrough of the Key Steps
1. Read the file and split each line on the ASCII comma
for line in emotion_file.readlines():
    line = line.strip().split(",")
After the split, line is a Python list: line[0] is the label and line[1] is the sentence (the sentences themselves use the fullwidth comma ',', so they are not broken apart).
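A minimal sketch of this step on one sample line from the run output below (assuming the `label,sentence` line format):

```python
# One raw line of the corpus: the ASCII comma separates label from sentence.
raw_line = '1,"距离川沙公路较近\n'

line = raw_line.strip().split(",")
label = int(line[0])   # 1
text = line[1]         # '"距离川沙公路较近'
print(label, text)
```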
2. Collect every character of the sentence into a set() to drop duplicates
vocab = set()
for char in text:
    vocab.add(char)
- text is one sentence
- for each character, set.add() inserts it only if it is not already present, so vocab holds each character exactly once
- vocab is of type set()
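The deduplication can be seen on a toy string (made up for illustration):

```python
vocab = set()
for char in "好好学习,天天向上":   # '好' and '天' each appear twice
    vocab.add(char)

# set.add() is a no-op for elements already present, so duplicates collapse:
# 10 characters in the string, but only 7 distinct ones in the set
print(len(vocab))
```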
3. Replace each character with its numeric index in the vocabulary
Compact version:
for text in context:
    token = [vocab_list.index(char) for char in text]
Expanded version:
for text in context:
    token = []                                # must be initialized as a list first, or append() fails
    for char in text:
        token.append(vocab_list.index(char))  # vocab_list was sorted earlier
Partial run output:
vocab_list:['!', '"', '2', 'M', '。', '一', '下', '不', '个', '也', '了', '人', '体', '公', '加', '务', '去', '品', '商', '多', '大', '太', '好', '实', '宽', '少', '川', '差', '床', '应', '店', '很', '惠', '感', '房', '整', '无', '早', '有', '本', '沙', '济', '的', '离', '经', '视', '觉', '论', '该', '距', '路', '身', '较', '边', '近', '这', '那', '酒', '重', '错', '问', '间', '题', '食', '餐', ',']
context: ['"距离川沙公路较近', '商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!', '早餐太差,无论去多少人,那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。']
text: "距离川沙公路较近
char: "
token: [1]
char: 距
token: [1, 49]
char: 离
token: [1, 49, 43]
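Note that vocab_list.index(char) re-scans the list for every character. The book's code works as-is, but a common speed-up (my addition, not from the book) is to build a char-to-index dictionary once:

```python
vocab_list = ['!', '"', '。', '距', '离']   # a tiny stand-in vocabulary

# Build the mapping once; each lookup is then O(1) instead of O(len(vocab_list))
char2idx = {char: idx for idx, char in enumerate(vocab_list)}

text = '"距离'
token = [char2idx[char] for char in text]

# Same result as the index() version
assert token == [vocab_list.index(char) for char in text]
print(token)  # [1, 3, 4]
```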
4. Truncate or pad each sentence to a length of 80 characters
token = token[:80] + [0] * (80 - len(token))
- [0]*(80-len(token)): fills the remaining slots with zeros when the sentence is shorter than 80 characters
Example: a token of length 42 needs 38 padding zeros
token: [37, 64, 21, 27, 65, 36, 47, 16, 19, 25, 11, 65, 56, 53, 9, 7, 14, 63, 17, 42, 4, 57, 30, 29, 48, 58, 45, 5, 6, 55, 8, 60, 62, 10, 4, 34, 61, 39, 51, 31, 22, 4]
token[:80]+[0]*(80-len(token)): [37, 64, 21, 27, 65, 36, 47, 16, 19, 25, 11, 65, 56, 53, 9, 7, 14, 63, 17, 42, 4, 57, 30, 29, 48, 58, 45, 5, 6, 55, 8, 60, 62, 10, 4, 34, 61, 39, 51, 31, 22, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
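A sketch of the two cases (padding and truncation); the helper name pad_to is mine:

```python
def pad_to(token, length=80):
    # [0] * (length - len(token)) is an empty list when len(token) >= length,
    # so one expression handles both padding and truncation.
    return token[:length] + [0] * (length - len(token))

short = list(range(42))    # 42 ids, like the example above
long = list(range(100))    # over-length sentence

assert len(pad_to(short)) == 80
assert pad_to(short)[42:] == [0] * 38    # 38 zeros of padding
assert pad_to(long) == list(range(80))   # truncated to the first 80 ids
```

Keras also ships a utility with the same effect, tf.keras.preprocessing.sequence.pad_sequences(..., maxlen=80, padding='post', truncating='post'), though the book builds the padding by hand.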
5. Append each sentence's token sequence to token_list (one token list per sentence)
token_list (example with three sentences):
[[1, 49, 43, 26, 40, 13, 50, 52, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[18, 15, 20, 28, 34, 65, 34, 61, 31, 20, 65, 28, 38, 2, 3, 24, 65, 35, 12, 33, 46, 44, 41, 23, 32, 7, 59, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[37, 64, 21, 27, 65, 36, 47, 16, 19, 25, 11, 65, 56, 53, 9, 7, 14, 63, 17, 42, 4, 57, 30, 29, 48, 58, 45, 5, 6, 55, 8, 60, 62, 10, 4, 34, 61, 39, 51, 31, 22, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
6. Format the dataset with np.array()
np.array(): builds a NumPy array from a Python list
token_list = np.array(token_list)
labels = np.array(labels)
Result: before: [1, 1, 1]   after: [1 1 1]
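The effect of the conversion, sketched on the three labels shown above (the token rows are shortened for illustration):

```python
import numpy as np

labels = [1, 1, 1]
token_list = [[1, 49, 43], [18, 15, 20], [37, 64, 21]]

labels = np.array(labels)
token_list = np.array(token_list)

# ndarray printing drops the commas, hence "[1 1 1]"
print(labels)               # [1 1 1]
print(token_list.shape)     # (3, 3)
```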
III. Training Output
Epoch 1/10
243/243 - 10s - loss: 0.4815 - accuracy: 0.7691
Epoch 2/10
243/243 - 10s - loss: 0.3476 - accuracy: 0.8514
Epoch 3/10
243/243 - 11s - loss: 0.2802 - accuracy: 0.8840
Epoch 4/10
243/243 - 12s - loss: 0.2436 - accuracy: 0.8990
Epoch 5/10
243/243 - 11s - loss: 0.2051 - accuracy: 0.9138
Epoch 6/10
243/243 - 11s - loss: 0.1654 - accuracy: 0.9339
Epoch 7/10
243/243 - 11s - loss: 0.1399 - accuracy: 0.9449
Epoch 8/10
243/243 - 11s - loss: 0.1187 - accuracy: 0.9517
Epoch 9/10
243/243 - 11s - loss: 0.0924 - accuracy: 0.9643
Epoch 10/10
243/243 - 11s - loss: 0.0798 - accuracy: 0.9679