文章目录
前言
词嵌入(Word Embedding)是将词语映射到低维连续向量空间的技术,它能够捕捉词语间的语义和语法关系。预训练词嵌入模型,如 Word2Vec(包括 Skip-gram 和 CBOW)和 GloVe,已经在自然语言处理 (NLP) 领域取得了巨大成功。这些模型通常在大型语料库上进行训练,学习到的词向量可以作为下游 NLP 任务的优秀特征输入。
本文将重点关注如何为预训练词嵌入模型(以 Skip-gram 和负采样为例)准备数据集。我们将使用 Penn Tree Bank (PTB) 数据集,并详细介绍从原始文本数据到可供 PyTorch 模型训练的小批量数据的完整处理流程。这个过程包括读取数据、构建词表、下采样高频词、提取中心词和上下文词、以及进行负采样。通过理解这些步骤,我们可以更好地掌握词嵌入模型训练的基础。
让我们开始吧!
完整代码:下载链接
辅助工具代码
在正式开始数据处理之前,我们先介绍两个辅助 Python 文件,它们分别提供了绘图和数据处理相关的功能。
绘图工具 (utils_for_huitu.py)
这个文件包含了一些使用 Matplotlib 进行绘图的辅助函数,例如设置图像大小、使用 SVG 格式显示以及绘制特定类型的直方图。
# --- START OF FILE utils_for_huitu.py ---
# 导入必要的包
import matplotlib.pyplot as plt # 用于创建和操作 Matplotlib 图表
from matplotlib_inline import backend_inline # 用于在Jupyter中设置Matplotlib输出格式
from IPython import display # 用于后续动态显示(如 Animator)
import torch # 导入PyTorch库,用于处理张量类型的图像
import numpy as np # 导入NumPy,可能用于数据处理
import matplotlib as mpl # 导入Matplotlib主模块,用于设置图像属性
def set_figsize(figsize=(3.5, 2.5)):
"""
设置matplotlib图形的大小
参数:
figsize: tuple[float, float] - 图形大小,形状为 (宽度, 高度),单位为英寸
输出:
无返回值
"""
plt.rcParams['figure.figsize'] = figsize # 设置图形默认大小
def use_svg_display():
"""
使用 SVG 格式在 Jupyter 中显示绘图
输入:
无
输出:
无返回值
"""
backend_inline.set_matplotlib_formats('svg') # 设置 Matplotlib 使用 SVG 格式
def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
"""
绘制列表长度对的直方图,用于比较两组列表中元素长度的分布
参数:
legend: list[str] - 图例标签,形状为 (2,),分别对应xlist和ylist的标签
xlabel: str - x轴标签
ylabel: str - y轴标签
xlist: list[list] - 第一组列表,形状为 (样本数量, 每个样本的元素数)
ylist: list[list] - 第二组列表,形状为 (样本数量, 每个样本的元素数)
输出:
无返回值,但会显示生成的直方图
"""
set_figsize() # 设置图形大小
# plt.hist返回的三个值:
# n: list[array] - 每个bin中的样本数量,形状为 (2, bin数量)
# bins: array - bin的边界值,形状为 (bin数量+1,)
# patches: list[list[Rectangle]] - 直方图的矩形对象,形状为 (2, bin数量)
_, _, patches = plt.hist(
[[len(l) for l in xlist], [len(l) for l in ylist]]) # 绘制两组数据长度的直方图
plt.xlabel(xlabel) # 设置x轴标签
plt.ylabel(ylabel) # 设置y轴标签
# 为第二组数据(ylist)的直方图添加斜线图案,以区分两组数据
for patch in patches[1].patches: # patches[1]是ylist对应的矩形对象列表
patch.set_hatch('/') # 设置填充图案为斜线
plt.legend(legend) # 添加图例
# --- END OF FILE utils_for_huitu.py ---
数据处理工具 (utils_for_data.py)
这个文件包含了一个用于统计词频的函数 count_corpus
和一个核心的 Vocab
类,后者用于构建词表,管理词元到索引以及索引到词元的映射。
# --- START OF FILE utils_for_data.py ---
from collections import Counter # 导入 Counter 类
# from collections import Counter # 用于词频统计 (此行重复,已注释)
import torch # PyTorch 核心库
from torch.utils import data # PyTorch 数据加载工具
import numpy as np # NumPy 用于数组操作
def count_corpus(tokens):
"""
统计词元的频率
参数:
tokens: 词元列表,可以是:
- 一维列表,例如 ['a', 'b']
- 二维列表,例如 [['a', 'b'], ['c']]
返回值:
Counter: Counter 对象,统计每个词元的出现次数
"""
# 如果输入为空列表,直接返回空计数器
if not tokens: # 等价于 len(tokens) == 0
return Counter()
# 检查输入是否为二维列表
if isinstance(tokens[0], list):
# 将二维列表展平为一维列表
flattened_tokens = [token for sublist in tokens for token in sublist]
else:
# 如果是一维列表,直接使用原列表
flattened_tokens = tokens
# 使用 Counter 统计词频并返回
return Counter(flattened_tokens)
class Vocab:
"""文本词表类,用于管理词元及其索引的映射关系"""
def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
"""初始化词表
Args:
tokens: 输入的词元列表,可以是1D或2D列表,默认为空列表
min_freq: 词元最小出现频率,小于此频率的词元将被忽略,默认为0
reserved_tokens: 预留的特殊词元列表(如'<pad>'),默认为空列表
"""
# 处理默认参数
self.tokens = tokens if tokens is not None else []
self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
# 统计词元频率并按频率降序排序
# 注意:这里应该调用类自身的 _count_corpus 方法
counter = self._count_corpus(self.tokens)
self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
# 初始化词表,'<unk>'为未知词元,索引为0
self.idx_to_token = ['<unk>'] + self.reserved_tokens
self.token_to_idx = {
token: idx for idx, token in enumerate(self.idx_to_token)}
# 添加满足最小频率要求的词元到词表
for token, freq in self._token_freqs:
if freq < min_freq:
break
if token not in self.token_to_idx:
self.idx_to_token.append(token)
self.token_to_idx[token] = len(self.idx_to_token) - 1
"""
将方法标记为静态方法,无需绑定实例或类,可用类名直接调用
"""
@staticmethod
def _count_corpus(tokens):
"""统计词元频率
Args:
tokens: 词元列表,可以是1D或2D列表
Returns:
Counter对象,统计每个词元的出现次数
"""
if not tokens:
return Counter()
if isinstance(tokens[0], list):
tokens = [token for sublist in tokens for token in sublist]
return Counter(tokens)
def __len__(self):
"""返回词表的大小"""
return len(self.idx_to_token)
def __getitem__(self, tokens):
"""通过词元获取索引,或通过索引获取词元
Args:
tokens: 单个词元或词元列表/元组
Returns:
单个索引或索引列表
"""
if not isinstance(tokens, (list, tuple)):
return self.token_to_idx.get(tokens, self.unk)
return [self[token] for token in tokens]
def to_tokens(self, indices):
"""通过索引获取词元
Args:
indices: 单个索引或索引列表/元组
Returns:
单个词元或词元列表
"""
if not isinstance(indices, (list, tuple)):
return self.idx_to_token[indices]
return [self.idx_to_token[index] for index in indices]
"""
用于将类中的方法伪装成属性(property),从而让开发者可以用访问属性的方式(而不是调用方法的方式)来获取或操作类的内部数据
"""
@property
def unk(self):
"""未知词元的索引"""
return 0
@property
def token_freqs(self):
"""词元及其频率的列表"""
return self._token_freqs
# --- END OF FILE utils_for_data.py ---
注意:在 Vocab
类的 __init__
方法中,原代码中 counter = count_corpus(self.tokens)
应该为 counter = self._count_corpus(self.tokens)
或 counter = Vocab._count_corpus(self.tokens)
以调用类自身的静态方法。上述代码已做此修正。
读取数据集 (PTB)
我们使用的数据集是 Penn Tree Bank (PTB)。该语料库取自“华尔街日报”的文章,分为训练集、验证集和测试集。在原始格式中,文本文件的每一行表示由空格分隔的一句话。在这里,我们将每个单词视为一个词元。
下面的 read_ptb
函数用于将PTB训练集加载到文本行的列表中。
import math
import os
import random
import torch
import numpy as np # 补充可能需要的数值计算库
def read_ptb():
"""
将PTB数据集加载到文本行的列表中
返回:
list[list[str]]: 句子列表,每个句子是词语的列表
形状为 (句子数量, 每句话的词数),
其中每句话的词数不固定
"""
data_dir = 'ptb' # 数据集目录 (str)
# 读取训练集文件
# 假设 'ptb/ptb.train.txt' 文件存在且包含数据
# 为确保代码可运行,如果文件不存在,可以创建一个虚拟文件或处理异常
if not os.path.exists