word2vec的pytorch实现

词向量表示如word2vec能捕捉词汇间的相似性和类比关系。本文介绍了word2vec的两种假设:CBOW和Skip-gram,重点讨论了Skip-gram模型,它以一个词预测其上下文。并提及在PTB数据集上使用PyTorch实现Skip-gram模型。
摘要由CSDN通过智能技术生成

词向量简介

ont-hot向量表示单词简单,但是不能表现出词语词之间的相似度
word2vec词嵌入可以解决上面的问题。word2vec将词表示成一个定长的向量,然后通过在语料库中的预训练使得这些向量能够学习到词与词之间的相似关系和类比关系。
word2vec有两种基本假设,一种是基于CBOW,另一种是基于Skip-gram。如果是用一个词语作为输入,来预测它周围的上下文,那这个模型叫做Skip-gram 模型而如果是拿一个词语的上下文作为输入,来预测这个词语本身,则是CBOW模型。
下图是
在这里插入图片描述下图是Skip-gram模型在这里插入图片描述

简单来说,word2vec是用一个一层的神经网络把one-hot形式的稀疏词向量映射称为一个n维(n一般为几百)的稠密向量的过程。为了加快模型训练速度,其中的tricks包括Hierarchical softmax,negative sampling, Huffman Tree等

PTB 数据集

这里我们使用经典的 PTB 语料库进行训练。PTB (Penn Tree Bank) 是一个常用的小型语料库,它采样自《华尔街日报》的文章,包括训练集、验证集和测试集。我们将在PTB训练集上训练词嵌入模型。

Skip-gram的pytorch实现

import collections
import math
import random
import sys
import time
import os
import numpy as np
import torch
from torch import nn
import torch.utils.data as Data

with open('ptb.train.txt', 'r') as f:
    lines = f.readlines() # 该数据集中句子以换行符为分割
    raw_dataset = [st.split() for st in lines] # st是sentence的缩写,单词以空格为分割
print('# sentences: %d' % len(raw_dataset))

# 对于数据集的前3个句子,打印每个句子的词数和前5个词
# 句尾符为 '' ,生僻词全用 '' 表示,数字则被替换成了 'N'
for st in raw_dataset[:3]:
    print('# tokens:', len(st), st[:5])

counter = collections.Counter([tk for st in raw_dataset for tk in st]) # tk是token的缩写
counter = dict(filter(lambda x: x[1] >= 5, counter.items())) # 只保留在数据集中至少出现5次的词

idx_to_token = [tk for tk, _ in counter.items()]
token_to_idx = {
   tk: idx for idx, tk in enumerate(idx_to_token)}
dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx]
           for st in raw_dataset] # raw_dataset中的单词在这一步被转换为对应的idx
num_tokens = sum([len(st) for st in dataset])

#二次采样操作。越高频率的词一般意义不大,根据公式高频词越容易被过滤。准确来说,应该是降频操作。既不希望超高频被完全过滤,又希望减少高频词对训练的影响。
def discard(idx):
    '''
    @params:
        idx: 单词的下标
    @return: True/False 表示是否丢弃该单词
    '''
    return random.uniform(0, 1) < 1 - math.sqrt(
        1e-4 / counter[idx_to_token[idx]] * num_tokens)

subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]
print('# tokens: %d' % sum([len(st) for st in subsampled_dataset]))

def compare_counts(token):
#展示了discard操作之后,剩下词的数量
    return '# %s: before=%d, after=%d' % (token, sum(
        [st.count(token_to_idx[token]) for st in dataset]), 
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值