I tried the batch-generation approach you proposed – a single loop over the data that uses the whole skip window. The results:
Batch generation is faster

With a batch size of 128 and a skip window of 5:

> Generating batches by looping through the data one word at a time takes 0.73 s per 10,000 batches.
> Generating batches with the tutorial code and num_skips = 2 takes 3.59 s per 10,000 batches.
Higher risk of overfitting

Keeping the rest of the tutorial code unchanged, I trained the model both ways and logged the average loss every 2,000 steps:

This pattern repeats consistently. It suggests that using 10 samples per word instead of 2 may lead to overfitting.
Here is the code I used to generate the batches. It replaces the tutorial's generate_batch function.
import numpy as np

data_index = 0

def generate_batch(batch_size, skip_window):
    global data_index
    batch = np.ndarray(shape=(batch_size,), dtype=np.int32)     # Row vector of context words
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)  # Column vector of center words
    # For each word in the data, add its context to the batch and the word itself to the labels
    batch_index = 0
    while batch_index < batch_size:
        context = data[get_context_indices(data_index, skip_window)]
        # Copy as much of the context as fits into the remaining batch space
        remaining_space = min(batch_size - batch_index, len(context))
        batch[batch_index:batch_index + remaining_space] = context[0:remaining_space]
        labels[batch_index:batch_index + remaining_space] = data[data_index]
        # Advance to the next center word, wrapping around the corpus
        batch_index += remaining_space
        data_index = (data_index + 1) % len(data)
    return batch, labels
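The helper get_context_indices is not shown above. A minimal sketch of what it presumably does, assuming it returns the positions of every word within skip_window slots on either side of the center word, clipped to the corpus bounds (the name and exact behavior are my reconstruction, not confirmed by the answer):

```python
import numpy as np

# Toy corpus of word ids, standing in for the tutorial's `data`.
data = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

# Hypothetical reconstruction of the helper called by generate_batch.
# Returns the indices of all words within skip_window positions on
# either side of center_index, clipped to the ends of the corpus.
def get_context_indices(center_index, skip_window):
    left = max(0, center_index - skip_window)
    right = min(len(data), center_index + skip_window + 1)
    return [i for i in range(left, right) if i != center_index]

# The context of position 5 with a window of 2 is positions 3, 4, 6, 7,
# so fancy indexing picks out the corresponding word ids.
print(get_context_indices(5, 2))        # [3, 4, 6, 7]
print(data[get_context_indices(5, 2)])  # [13 14 16 17]
```

Near the edges of the corpus the context is shorter, which is why generate_batch above clamps remaining_space to len(context) rather than assuming a fixed 2 * skip_window samples per word.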