tf.data.Dataset.from_tensor_slices
函数会将输入的张量或列表沿着第一维切片,从而创建一个新的数据集,其中每个元素都是原始张量或列表中的一个切片。
2. sequences = dataset.batch(seq_length+1, drop_remainder=True)
这行代码将dataset
中的元素组合成批次。每个批次的大小是seq_length+1
。drop_remainder=True
意味着如果数据集的大小不是seq_length+1
的整数倍,那么剩余的元素将被丢弃。
import tensorflow as tf
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.models import Sequential
# 定义模型
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
model = Sequential([
Embedding(vocab_size, embedding_dim, batch_input_shape=[batch_size, None]),
LSTM(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
Dense(vocab_size)
])
return model
# 构建一个简单的数据集
text = "这是一个简单的例子,用于演示语言模型的基本实现。"
vocab = sorted(set(text))
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)
text_as_int = np.array([char2idx[c] for c in text])
# 将数据集转换为输入序列和目标序列
def split_input_target(chunk):
input_text = chunk[:-1]
target_text = chunk[1:]
return input_text, target_text
dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = dataset.batch(seq_length+1, drop_remainder=True)
dataset = sequences.map(split_input_target)
# 设置模型参数
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024
batch_size = 64
# 构建并编译模型
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size)
model.compile(optimizer='adam', loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True))
# 训练模型
model.fit(dataset, epochs=100)
在深度学习中,Dense
是一种用于定义全连接层的Keras层。全连接层也称为密集连接层,它将输入张量与权重矩阵相乘,并添加一个偏置向量,然后通过激活函数激活输出。在语言模型中,全连接层通常用于将模型的最后一层映射到输出词汇表的概率分布。在上面提到的例子中,Dense
层用于模型的输出层,以预测下一个词的概率分布。
Sequential
是Keras中的一个模型容器,它允许你按顺序堆叠各种层以构建神经网络模型。Sequential
模型简单易用,适用于构建简单的线性堆叠模型,其中每一层都只有一个输入张量和一个输出张量。在上面提到的例子中,我们使用Sequential
来定义模型,按顺序添加了嵌入层(Embedding)、LSTM层和全连接层(Dense)。
- 内存分配:数组在内存中是一块连续的空间,因此其索引速度非常快。列表则是由多个对象组成的集合,其元素在内存中不一定是连续的,因此其索引速度可能较慢。
- 列表则主要提供了一些简单的操作,如添加、删除和查找元素等
- 列表和数组各有其优点和适用场景。在需要处理固定大小、固定类型的数据集,并且需要进行大量数学运算时,数组可能是一个更好的选择。而在需要处理动态变化、元素类型多样的数据集时,列表可能更为合适。
from_tensor_slices
是 TensorFlow 数据集 API 中的一个方法,用于从给定的张量(tensor)中创建一个数据集。它会将张量的每个元素视为一个样本,并将它们作为数据集的一部分。
.numpy()
方法将样本转换为 NumPy 数组,以方便打印出来。
这个数组可以更方便地进行索引、切片等操作,并且可以与其他 NumPy 数组进行数学运算。