Xiaohei's NLP Baseline Diary 1: A PyTorch Exercise on Skip-Gram + NEG (Negative Sampling)

This post walks through a small, self-contained PyTorch implementation of skip-gram with negative sampling (SGNS), trained on the text8 corpus.

import argparse
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
def ArgumentParser():    # basic hyperparameter configuration
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name',type = str,default = 'skip-gram',help = 'skip-gram or cbow')
    parser.add_argument('--window_size',type = int,default = 3,help = 'window size in word2vec')
    parser.add_argument('--batch_size',type = int,default = 256,help = 'batch size during training phase')
    parser.add_argument('--min_count',type = int,default = 3,help = 'min count of training word')
    parser.add_argument('--embed_dimension',type = int,default = 100,help = 'embedding dimension of word embedding')
    parser.add_argument('--learning_rate',type = float,default = 0.02,help = 'learning rate during training phase')
    parser.add_argument('--neg_count',type = int,default = 5,help = 'number of negative samples per positive pair')
    return parser.parse_known_args()[0]
args = ArgumentParser()
WINDOW_SIZE = args.window_size    # context window size c
BATCH_SIZE = args.batch_size    # mini-batch size
MIN_COUNT = args.min_count    # frequency threshold: rarer words are discarded
EMB_DIMENSION = args.embed_dimension    # embedding dimension
LR = args.learning_rate    # learning rate
NEG_COUNT = args.neg_count    # number of negative samples
# Data input class
class InputData:
    def __init__(self,input_file_name,min_count):
        self.input_file_name = input_file_name
        self.index = 0    # index of the current center word
        self.input_file = open(self.input_file_name,'r',encoding = 'utf-8')    # input corpus
        self.min_count = min_count    # minimum word frequency
        self.wordid_frequency_dict = dict()
        self.word_count = 0    # number of words in the vocabulary
        self.word_count_sum = 0    # total number of tokens in the corpus
        self.sentence_count = 0
        self.id2word_dict = dict()
        self.word2id_dict = dict()
        self._init_dict()    # build the dictionaries
        self.sample_table = []    # negative-sampling table
        self._init_sample_table()    # build the negative-sampling table
        self.get_wordId_list()    # map the corpus to a list of word ids
        self.word_pairs_queue = deque()    # queue of word pairs, used to serve mini-batches
        # report corpus statistics
        print('Word Count is:', self.word_count)
        print('Word Count Sum is:', self.word_count_sum)
        print('Sentence Count is:', self.sentence_count)
    def _init_dict(self):
        print('Building word_freq...')
        word_freq = dict()
        for line in self.input_file:
            line = line.strip().split()
            self.word_count_sum += len(line)
            self.sentence_count += 1
            for i,word in enumerate(line):
                if i % 1000000 == 0:
                    print(i,len(line))
                word_freq[word] = word_freq.get(word,0) + 1
        self.input_file.close()
        print('Building word2id_dict, id2word_dict, wordid_frequency_dict...')
        for i,word in enumerate(word_freq):
            if i % 100000 == 0:
                print(i,len(word_freq))
            if word_freq[word] < self.min_count:
                self.word_count_sum -= word_freq[word]
                continue
            word_id = len(self.word2id_dict)    # assign ids consecutively; the same id must key all three dicts
            self.word2id_dict[word] = word_id
            self.id2word_dict[word_id] = word
            self.wordid_frequency_dict[word_id] = word_freq[word]
        self.word_count = len(self.word2id_dict)
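    # Negative samples are drawn from the smoothed unigram distribution of the
    # original word2vec paper (Mikolov et al., 2013):
    #   P(w_i) = f(w_i)^0.75 / sum_j f(w_j)^0.75
    # _init_sample_table materializes this distribution as a large table so that
    # sampling reduces to a uniform draw over table entries.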
    def _init_sample_table(self):
        sample_table_size = 1e8
        pow_frequency = np.array(list(self.wordid_frequency_dict.values())) ** 0.75
        word_pow_sum = sum(pow_frequency)
        ratio_array = pow_frequency / word_pow_sum
        word_count_list = np.round(ratio_array * sample_table_size)
        for word_index,word_freq in enumerate(word_count_list):
            self.sample_table += [word_index] * int(word_freq)
        self.sample_table = np.array(self.sample_table)
        np.random.shuffle(self.sample_table)
    def get_wordId_list(self):
        self.input_file = open(self.input_file_name,encoding = 'utf-8')
        sentence = self.input_file.readline()    # text8 is one long line, so a single readline() covers the corpus
        self.input_file.close()
        wordId_list = []    # ids of every in-vocabulary word in the sentence
        sentence = sentence.strip().split(' ')
        print('Building wordId_list...')
        for i,word in enumerate(sentence):
            if i % 1000000 == 0:
                print(i,len(sentence))
            try:
                word_id = self.word2id_dict[word]
                wordId_list.append(word_id)
            except KeyError:    # word was pruned by min_count
                continue
        self.wordId_list = wordId_list
    def get_batch_pairs(self,batch_size,window_size):
        while len(self.word_pairs_queue) < batch_size:
            for _ in range(1000):
                if self.index == len(self.wordId_list):
                    self.index = 0    # wrap around to the start of the corpus
                wordId_w = self.wordId_list[self.index]
                # the +1 matters: range() excludes its upper bound, so without it
                # the right context window would be one word short
                for i in range(max(self.index - window_size,0),min(self.index + window_size + 1,len(self.wordId_list))):
                    wordId_v = self.wordId_list[i]
                    if self.index == i:    # skip when the context position equals the center position
                        continue
                    self.word_pairs_queue.append((wordId_w,wordId_v))
                self.index += 1
        result_pairs = []    # return a mini-batch of positive pairs
        for _ in range(batch_size):
            result_pairs.append(self.word_pairs_queue.popleft())
        return result_pairs
    # Negative sampling: given the array of positive pairs and the number of
    # negatives needed per pair (neg_count), draw negative word ids from the
    # sampling table. (Assume the corpus is large enough to ignore the small
    # chance that a negative sample equals the positive word.)
    def get_negative_sampling(self,positive_pairs,neg_count):
        neg_v = np.random.choice(self.sample_table,size = (len(positive_pairs),neg_count)).tolist()
        return neg_v
    # Estimate the number of positive pairs in the corpus, used to size the training run
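    # Derivation: each token contributes up to 2*window_size pairs, but every
    # sentence loses 1 + 2 + ... + window_size pairs at each end, i.e.
    # (1 + window_size) * window_size pairs per sentence in total.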
    def evaluate_pairs_count(self,window_size):
        return self.word_count_sum * (2 * window_size) - self.sentence_count * (1 + window_size) * window_size
class SkipGramModel(nn.Module):
    def __init__(self,vocab_size,embed_size):
        super(SkipGramModel,self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.w_embeddings = nn.Embedding(vocab_size,embed_size)
        self.v_embeddings = nn.Embedding(vocab_size,embed_size)
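        # Two tables, as in word2vec: w_embeddings holds center-word ("input")
        # vectors, v_embeddings holds context ("output") vectors.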
        self._init_emb()
    def _init_emb(self):
        initrange = 0.5 / self.embed_size
        self.w_embeddings.weight.data.uniform_(-initrange,initrange)
        self.v_embeddings.weight.data.zero_()    # context vectors start at zero, as in the original word2vec C code
    def forward(self,pos_w,pos_v,neg_v):
        emb_w = self.w_embeddings(torch.LongTensor(pos_w))    # [batch, emb_dim]
        emb_v = self.v_embeddings(torch.LongTensor(pos_v))    # [batch, emb_dim]
        neg_emb_v = self.v_embeddings(torch.LongTensor(neg_v))    # [batch, neg_count, emb_dim]
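        # SGNS loss for a positive pair (w, v) with K sampled negatives v'_k:
        #   loss = -log sigmoid(u_w . u'_v) - sum_k log sigmoid(-u_w . u'_{v'_k})
        # clamping the pre-sigmoid scores to [-10, 10] keeps logsigmoid numerically stable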
        score = torch.mul(emb_w,emb_v)   
        score = torch.sum(score,dim = 1)
        score = torch.clamp(score,max = 10,min = -10)
        score = F.logsigmoid(score)
        neg_score = torch.bmm(neg_emb_v,emb_w.unsqueeze(2))
        neg_score = torch.clamp(neg_score,max = 10,min = -10)
        neg_score = F.logsigmoid(-1 * neg_score)
        loss = - torch.sum(score) - torch.sum(neg_score)
        return loss
    def save_embedding(self,id2word,file_name):
        embedding_1 = self.w_embeddings.weight.data.cpu().numpy()
        embedding_2 = self.v_embeddings.weight.data.cpu().numpy()
        embedding = (embedding_1 + embedding_2) / 2    # average the two tables into the final vectors
        with open(file_name,'w') as fout:    # word2vec text format: header line, then one word per line
            fout.write('%d %d\n' % (len(id2word),self.embed_size))
            for wid,w in id2word.items():
                e = embedding[wid]
                e = ' '.join(map(str,e))
                fout.write('%s %s\n' % (w,e))
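# Note: averaging the input and output tables is just one convention; keeping
# only the w (input) vectors, as the original word2vec tool does, is equally common.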
class Word2Vec:
    def __init__(self,input_file_name,output_file_name):
        self.output_file_name = output_file_name
        self.data = InputData(input_file_name,MIN_COUNT)
        self.model = SkipGramModel(self.data.word_count,EMB_DIMENSION).cpu()
        self.lr = LR
        self.optimizer = optim.SGD(self.model.parameters(),lr = self.lr)
    def train(self):
        print('SkipGram Training......')
        pairs_count = self.data.evaluate_pairs_count(WINDOW_SIZE)
        print('pairs_count',pairs_count)
        batch_count = pairs_count / BATCH_SIZE
        print('batch_count',batch_count)
        process_bar = tqdm(range(int(5 * batch_count)))    # roughly 5 passes over all positive pairs
        for i in process_bar:
            pos_pairs = self.data.get_batch_pairs(BATCH_SIZE,WINDOW_SIZE)
            pos_w = [int(pair[0]) for pair in pos_pairs]
            pos_v = [int(pair[1]) for pair in pos_pairs]
            neg_v = self.data.get_negative_sampling(pos_pairs,NEG_COUNT)
            self.optimizer.zero_grad()
            loss = self.model(pos_w,pos_v,neg_v)    # call the module itself rather than .forward()
            loss.backward()
            self.optimizer.step()
            process_bar.set_postfix(loss = loss.item())    # .item() yields a plain float for display
            # no explicit update(): iterating the tqdm object already advances the bar
        torch.save(self.model.state_dict(),'./test_skipgram_neg.pkl')
        self.model.save_embedding(self.data.id2word_dict,self.output_file_name)
w2v = Word2Vec(input_file_name='./word2vec/data/text8.txt', output_file_name="../results/skip_gram_neg.txt")
w2v.train()
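
The vectors are written in the standard word2vec text format, so any compatible loader can inspect them. A minimal sketch, assuming gensim is installed and that the query word survived the min_count filter:

from gensim.models import KeyedVectors

# load_word2vec_format expects the "vocab_size dim" header plus one
# "word v1 v2 ..." line per word -- exactly what save_embedding writes
vectors = KeyedVectors.load_word2vec_format('../results/skip_gram_neg.txt', binary=False)
print(vectors.most_similar('king', topn=5))    # nearest neighbours by cosine similarity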