从代码学习深度学习 - 用于预训练词嵌入的数据集 PyTorch版


前言

词嵌入(Word Embedding)是将词语映射到低维连续向量空间的技术,它能够捕捉词语间的语义和语法关系。预训练词嵌入模型,如 Word2Vec(包括 Skip-gram 和 CBOW)和 GloVe,已经在自然语言处理 (NLP) 领域取得了巨大成功。这些模型通常在大型语料库上进行训练,学习到的词向量可以作为下游 NLP 任务的优秀特征输入。

本文将重点关注如何为预训练词嵌入模型(以 Skip-gram 和负采样为例)准备数据集。我们将使用 Penn Tree Bank (PTB) 数据集,并详细介绍从原始文本数据到可供 PyTorch 模型训练的小批量数据的完整处理流程。这个过程包括读取数据、构建词表、下采样高频词、提取中心词和上下文词、以及进行负采样。通过理解这些步骤,我们可以更好地掌握词嵌入模型训练的基础。

让我们开始吧!

完整代码:下载链接

辅助工具代码

在正式开始数据处理之前,我们先介绍两个辅助 Python 文件,它们分别提供了绘图和数据处理相关的功能。

绘图工具 (utils_for_huitu.py)

这个文件包含了一些使用 Matplotlib 进行绘图的辅助函数,例如设置图像大小、使用 SVG 格式显示以及绘制特定类型的直方图。

# --- START OF FILE utils_for_huitu.py ---

# 导入必要的包
import matplotlib.pyplot as plt  # 用于创建和操作 Matplotlib 图表
from matplotlib_inline import backend_inline  # 用于在Jupyter中设置Matplotlib输出格式
from IPython import display  # 用于后续动态显示(如 Animator)
import torch  # 导入PyTorch库,用于处理张量类型的图像
import numpy as np  # 导入NumPy,可能用于数据处理
import matplotlib as mpl  # 导入Matplotlib主模块,用于设置图像属性

def set_figsize(figsize=(3.5, 2.5)):
    """
    设置matplotlib图形的大小
    
    参数:
        figsize: tuple[float, float] - 图形大小,形状为 (宽度, 高度),单位为英寸
        
    输出:
        无返回值
    """
    plt.rcParams['figure.figsize'] = figsize  # 设置图形默认大小

def use_svg_display():
    """
    使用 SVG 格式在 Jupyter 中显示绘图
    
    输入:
        无
    输出:
        无返回值
    """
    backend_inline.set_matplotlib_formats('svg')  # 设置 Matplotlib 使用 SVG 格式

def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
    """
    绘制列表长度对的直方图,用于比较两组列表中元素长度的分布
    
    参数:
        legend: list[str] - 图例标签,形状为 (2,),分别对应xlist和ylist的标签
        xlabel: str - x轴标签
        ylabel: str - y轴标签
        xlist: list[list] - 第一组列表,形状为 (样本数量, 每个样本的元素数)
        ylist: list[list] - 第二组列表,形状为 (样本数量, 每个样本的元素数)
    
    输出:
        无返回值,但会显示生成的直方图
    """
    set_figsize()  # 设置图形大小
    
    # plt.hist返回的三个值:
    # n: list[array] - 每个bin中的样本数量,形状为 (2, bin数量)
    # bins: array - bin的边界值,形状为 (bin数量+1,)
    # patches: list[list[Rectangle]] - 直方图的矩形对象,形状为 (2, bin数量)
    _, _, patches = plt.hist(
        [[len(l) for l in xlist], [len(l) for l in ylist]])  # 绘制两组数据长度的直方图
    
    plt.xlabel(xlabel)  # 设置x轴标签
    plt.ylabel(ylabel)  # 设置y轴标签
    
    # 为第二组数据(ylist)的直方图添加斜线图案,以区分两组数据
    for patch in patches[1].patches:  # patches[1]是ylist对应的矩形对象列表
        patch.set_hatch('/')  # 设置填充图案为斜线
    
    plt.legend(legend)  # 添加图例
# --- END OF FILE utils_for_huitu.py ---

数据处理工具 (utils_for_data.py)

这个文件包含了一个用于统计词频的函数 count_corpus 和一个核心的 Vocab 类,后者用于构建词表,管理词元到索引以及索引到词元的映射。

# --- START OF FILE utils_for_data.py ---

from collections import Counter  # 导入 Counter 类
# from collections import Counter  # 用于词频统计 (此行重复,已注释)
import torch  # PyTorch 核心库
from torch.utils import data  # PyTorch 数据加载工具
import numpy as np  # NumPy 用于数组操作

def count_corpus(tokens):
    """
    统计词元的频率
    
    参数:
        tokens: 词元列表,可以是:
            - 一维列表,例如 ['a', 'b']
            - 二维列表,例如 [['a', 'b'], ['c']]
    
    返回值:
        Counter: Counter 对象,统计每个词元的出现次数
    """
    # 如果输入为空列表,直接返回空计数器
    if not tokens:  # 等价于 len(tokens) == 0
        return Counter()
    
    # 检查输入是否为二维列表
    if isinstance(tokens[0], list):
        # 将二维列表展平为一维列表
        flattened_tokens = [token for sublist in tokens for token in sublist]
    else:
        # 如果是一维列表,直接使用原列表
        flattened_tokens = tokens
    
    # 使用 Counter 统计词频并返回
    return Counter(flattened_tokens)


class Vocab:
    """文本词表类,用于管理词元及其索引的映射关系"""

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        """初始化词表

        Args:
            tokens: 输入的词元列表,可以是1D或2D列表,默认为空列表
            min_freq: 词元最小出现频率,小于此频率的词元将被忽略,默认为0
            reserved_tokens: 预留的特殊词元列表(如'<pad>'),默认为空列表
        """
        # 处理默认参数
        self.tokens = tokens if tokens is not None else []
        self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []

        # 统计词元频率并按频率降序排序
        # 注意:这里应该调用类自身的 _count_corpus 方法
        counter = self._count_corpus(self.tokens) 
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)

        # 初始化词表,'<unk>'为未知词元,索引为0
        self.idx_to_token = ['<unk>'] + self.reserved_tokens
        self.token_to_idx = {
   token: idx for idx, token in enumerate(self.idx_to_token)}

        # 添加满足最小频率要求的词元到词表
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1
    
    """
    将方法标记为静态方法,无需绑定实例或类,可用类名直接调用
    """
    @staticmethod
    def _count_corpus(tokens):
        """统计词元频率

        Args:
            tokens: 词元列表,可以是1D或2D列表

        Returns:
            Counter对象,统计每个词元的出现次数
        """
        if not tokens:
            return Counter()
        if isinstance(tokens[0], list):
            tokens = [token for sublist in tokens for token in sublist]
        return Counter(tokens)

    def __len__(self):
        """返回词表的大小"""
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """通过词元获取索引,或通过索引获取词元

        Args:
            tokens: 单个词元或词元列表/元组

        Returns:
            单个索引或索引列表
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self[token] for token in tokens]

    def to_tokens(self, indices):
        """通过索引获取词元

        Args:
            indices: 单个索引或索引列表/元组

        Returns:
            单个词元或词元列表
        """
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]
        
    """
    用于将类中的方法伪装成属性(property),从而让开发者可以用访问属性的方式(而不是调用方法的方式)来获取或操作类的内部数据
    """
    @property
    def unk(self):
        """未知词元的索引"""
        return 0

    @property
    def token_freqs(self):
        """词元及其频率的列表"""
        return self._token_freqs
# --- END OF FILE utils_for_data.py ---

注意:在 Vocab 类的 __init__ 方法中,原代码中 counter = count_corpus(self.tokens) 应该为 counter = self._count_corpus(self.tokens)counter = Vocab._count_corpus(self.tokens) 以调用类自身的静态方法。上述代码已做此修正。

读取数据集 (PTB)

我们使用的数据集是 Penn Tree Bank (PTB)。该语料库取自“华尔街日报”的文章,分为训练集、验证集和测试集。在原始格式中,文本文件的每一行表示由空格分隔的一句话。在这里,我们将每个单词视为一个词元。

下面的 read_ptb 函数用于将PTB训练集加载到文本行的列表中。

import math
import os
import random
import torch
import numpy as np  # 补充可能需要的数值计算库

def read_ptb():
    """
    将PTB数据集加载到文本行的列表中
    
    返回:
        list[list[str]]: 句子列表,每个句子是词语的列表
                         形状为 (句子数量, 每句话的词数),
                         其中每句话的词数不固定
    """
    data_dir = 'ptb'  # 数据集目录 (str)
    # 读取训练集文件
    # 假设 'ptb/ptb.train.txt' 文件存在且包含数据
    # 为确保代码可运行,如果文件不存在,可以创建一个虚拟文件或处理异常
    if not os.path.exists
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值