Skip-gram in NLP

'''
Description: skip-gram
Author: 365JHWZGo
Date: 2021-11-25 17:02:34
LastEditors: 365JHWZGo
LastEditTime: 2021-11-28 18:00:11
'''
import torch
from torch import nn
from torch.nn.functional import cross_entropy
from utils import Dataset, process_w2v_data
from visual import show_w2v_word_embedding
import matplotlib.pyplot as plt

corpus = [
    # numbers
    "5 2 4 8 6 2 3 6 4",
    "4 8 5 6 9 5 5 6",
    "1 1 5 2 3 3 8",
    "3 6 9 6 8 7 4 6 3",
    "8 9 9 6 1 4 3 4",
    "1 0 2 0 2 1 3 3 3 3 3",
    "9 3 3 0 1 4 7 8",
    "9 9 8 5 6 7 1 2 3 0 1 0",

    # letters; we expect "9" to end up close to these
    "a t g q e h 9 u f",
    "e q y u o i p s",
    "q o 9 p l k j o k k o p",
    "h g y i u t t a e q",
    "i k d q r e 9 e a d",
    "o p d g 9 s a f g a",
    "i u y g h k l a s w",
    "o l u y a o g f s",
    "o p i u y g d a s j d l",
    "u k i l o 9 l j s",
    "y g i s h k j l f r f",
    "i o h n 9 9 d 9 f a 9",
]


class SkipGram(nn.Module):
    def __init__(self, v_dim, emb_dim):
        super().__init__()
        self.v_dim = v_dim
        self.embeddings = nn.Embedding(v_dim, emb_dim)
        self.embeddings.weight.data.normal_(0, 0.1)
        self.hidden_out = nn.Linear(emb_dim, v_dim)

        self.opt = torch.optim.Adam(self.parameters(), lr=0.01)

    def forward(self, x, training=None, mask=None):
        # [n, ] word ids -> [n, emb_dim] center-word vectors
        o = self.embeddings(x)
        return o

    def loss(self, x, y, training=None):
        embedded = self(x, training)
        pred = self.hidden_out(embedded)   # [n, v_dim] logits over the whole vocabulary
        return cross_entropy(pred, y)      # y holds the true context-word ids

    def step(self, x, y):
        self.opt.zero_grad()
        loss = self.loss(x, y, True)
        loss.backward()
        self.opt.step()
        return loss.detach().numpy()

if __name__ == '__main__':
    plt.ion()   # interactive mode so the embedding plot can refresh during training
    d = process_w2v_data(corpus, skip_window=2, method="skip_gram")
    m = SkipGram(d.num_word, 2)   # 2-D embeddings, so they can be plotted directly
    for step in range(2500):
        bx, by = d.sample(8)
        bx, by = torch.from_numpy(bx), torch.from_numpy(by)
        loss = m.step(bx, by)
        if step % 100 == 0:
            print(f"step:{step} |   loss:{loss}")
            show_w2v_word_embedding(m, d, "./nlp/images/skipgram.png")
    plt.ioff()
    plt.show()
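For orientation, the shape flow of a single training step looks like this (my annotation, not part of the original script; the batch size of 8 and the 2-dimensional embedding come from the calls above):

# bx: LongTensor [8]   center-word ids sampled from the (center, context) pairs
# by: LongTensor [8]   the context-word ids the model must predict
embedded = m.embeddings(bx)        # [8, 2]   the 2-D word vectors
pred = m.hidden_out(embedded)      # [8, d.num_word]   one logit per vocabulary word
loss = cross_entropy(pred, by)     # softmax over the vocabulary, NLL of the true context id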

At the start:
[figure: initial 2-D word embeddings, a random scatter of digits (blue) and letters (red)]
After skip-gram training:
[figure: digits and letters separate into two clusters, with "9" ending up near the letters as expected]

step:0 |   loss:3.499077320098877
step:100 |   loss:3.149627685546875
step:200 |   loss:2.8046340942382812
step:300 |   loss:2.7037036418914795
step:400 |   loss:2.573456048965454
step:500 |   loss:3.1753740310668945
step:600 |   loss:3.0493807792663574
step:700 |   loss:2.612898588180542
step:800 |   loss:2.8662161827087402
step:900 |   loss:2.6404876708984375
step:1000 |   loss:2.2015721797943115
step:1100 |   loss:2.0096023082733154
step:1200 |   loss:2.857692003250122
step:1300 |   loss:2.4460296630859375
step:1400 |   loss:2.9422173500061035
step:1500 |   loss:2.8368630409240723
step:1600 |   loss:3.072591781616211
step:1700 |   loss:2.311462879180908
step:1800 |   loss:2.597994327545166
step:1900 |   loss:2.083057403564453
step:2000 |   loss:2.365464210510254
step:2100 |   loss:2.69895601272583
step:2200 |   loss:2.926234722137451
step:2300 |   loss:2.446007490158081
step:2400 |   loss:2.5399281978607178
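Once training ends, one can check the "9 should sit near the letters" expectation numerically with a cosine nearest-neighbour lookup in the learned embedding table. The helper below is my own sketch (not part of the original post); it only assumes the m and d objects created in the script above:

import numpy as np

def nearest(model, data, word, k=5):
    # cosine similarity between `word` and every row of the embedding matrix
    emb = model.embeddings.weight.data.numpy()               # [num_word, emb_dim]
    v = emb[data.v2i[word]]
    sim = emb @ v / (np.linalg.norm(emb, axis=1) * np.linalg.norm(v) + 1e-8)
    order = np.argsort(sim)[::-1]
    return [data.i2v[i] for i in order[1:k + 1]]             # drop the query word itself

print(nearest(m, d, "9"))   # should list mostly letters if training worked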

visual.py

'''
Description: day1118-2
Author: 365JHWZGo
Date: 2021-11-18 17:05:37
LastEditors: 365JHWZGo
LastEditTime: 2021-11-28 18:00:01
'''
import matplotlib.pyplot as plt

def show_w2v_word_embedding(model, data, path):
    plt.cla()
    word_emb = model.embeddings.weight.data.numpy()
    for i in range(data.num_word):
        # digits are drawn in blue, everything else (letters) in red
        c = "blue"
        try:
            int(data.i2v[i])
        except ValueError:
            c = "red"

        plt.text(word_emb[i, 0], word_emb[i, 1], s=data.i2v[i], color=c, weight="bold")

    plt.xlim(word_emb[:, 0].min() - 0.5, word_emb[:, 0].max() + 0.5)
    plt.ylim(word_emb[:, 1].min() - 0.5, word_emb[:, 1].max() + 0.5)
    plt.xticks(())
    plt.yticks(())
    plt.xlabel("embedding dim1")
    plt.ylabel("embedding dim2")
    plt.savefig(path, dpi=300, format="png")
    # plt.show()
    plt.pause(0.1)


utils.py

'''
Description: dependencies for day1118-2
Author: 365JHWZGo
Date: 2021-11-18 17:03:54
LastEditors: 365JHWZGo
LastEditTime: 2021-11-18 17:04:12
'''
import itertools
import numpy as np
from torch.utils.data import Dataset as tDataset
import datetime
import os
import re
import pandas as pd
import requests
import torch

PAD_ID = 0


class DateData(tDataset):
    def __init__(self, n):
        np.random.seed(1)
        self.date_cn = []
        self.date_en = []
        for timestamp in np.random.randint(143835585, 2043835585, n):
            date = datetime.datetime.fromtimestamp(timestamp)
            self.date_cn.append(date.strftime("%y-%m-%d"))
            self.date_en.append(date.strftime("%d/%b/%Y"))
        self.vocab = set(
            [str(i) for i in range(0, 10)] + ["-", "/", "<GO>", "<EOS>"] + [i.split("/")[1] for i in self.date_en]
        )
        self.v2i = {v: i for i, v in enumerate(sorted(list(self.vocab)), start=1)}
        self.v2i["<PAD>"] = PAD_ID
        self.vocab.add("<PAD>")
        self.i2v = {i: v for v, i in self.v2i.items()}
        self.x, self.y = [], []
        for cn, en in zip(self.date_cn, self.date_en):
            self.x.append([self.v2i[v] for v in cn])
            self.y.append([self.v2i["<GO>"], ] + [self.v2i[v] for v in en[:3]] + [
                self.v2i[en[3:6]]] + [self.v2i[v] for v in en[6:]] + [self.v2i["<EOS>"], ])
        self.x, self.y = np.array(self.x), np.array(self.y)
        self.start_token = self.v2i["<GO>"]
        self.end_token = self.v2i["<EOS>"]

    def __len__(self):
        return len(self.x)

    @property
    def num_word(self):
        return len(self.vocab)

    def __getitem__(self, index):
        return self.x[index], self.y[index], len(self.y[index]) - 1

    def idx2str(self, idx):
        x = []
        for i in idx:
            x.append(self.i2v[i])
            if i == self.end_token:
                break
        return "".join(x)

def pad_zero(seqs, max_len):
    padded = np.full((len(seqs), max_len), fill_value=PAD_ID, dtype=np.int32)
    for i, seq in enumerate(seqs):
        padded[i, :len(seq)] = seq
    return padded
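pad_zero simply right-pads every id sequence with PAD_ID up to max_len. A tiny example of my own:

print(pad_zero([[3, 1], [4, 1, 5, 9]], max_len=5))
# [[3 1 0 0 0]
#  [4 1 5 9 0]]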

class Dataset:
    def __init__(self, x, y, v2i, i2v):
        self.x, self.y = x, y
        self.v2i, self.i2v = v2i, i2v
        self.vocab = v2i.keys()

    def sample(self, n):
        # draw n random (center, context) pairs for one training batch
        b_idx = np.random.randint(0, len(self.x), n)
        bx, by = self.x[b_idx], self.y[b_idx]
        return bx, by

    @property
    def num_word(self):
        return len(self.v2i)

def process_w2v_data(corpus, skip_window=2, method="skip_gram"):
    all_words = [sentence.split(" ") for sentence in corpus]
    # chain groups all the iterables together and produces a single iterable
    all_words = np.array(list(itertools.chain(*all_words)))
    vocab, v_count = np.unique(all_words, return_counts=True)
    vocab = vocab[np.argsort(v_count)[::-1]]

    print("All vocabularies are sorted by frequency in decreasing order")
    v2i = {v: i for i, v in enumerate(vocab)}
    i2v = {i: v for v, i in v2i.items()}

    pairs = []
    # relative context offsets, e.g. [-2, -1, 1, 2] for skip_window=2
    js = [i for i in range(-skip_window, skip_window + 1) if i != 0]

    for c in corpus:
        words = c.split(" ")
        w_idx = [v2i[w] for w in words]
        if method.lower() == "skip_gram":
            for i in range(len(w_idx)):
                for j in js:
                    if i + j < 0 or i + j >= len(w_idx):
                        continue  # context position falls outside the sentence
                    pairs.append((w_idx[i], w_idx[i + j]))  # (center, context)
        elif method.lower() == "cbow":
            for i in range(skip_window, len(w_idx) - skip_window):
                context = []
                for j in js:
                    context.append(w_idx[i + j])
                pairs.append(context + [w_idx[i]])  # (context..., center)
        else:
            raise ValueError

    pairs = np.array(pairs)
    print("5 example pairs:\n", pairs[:5])
    if method.lower() == "skip_gram":
        x, y = pairs[:, 0], pairs[:, 1]
    elif method.lower() == "cbow":
        x, y = pairs[:, :-1], pairs[:, -1]
    else:
        raise ValueError
    return Dataset(x, y, v2i, i2v)
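To make the windowing concrete, here is a small trace of my own. With skip_window=2, every (center, context) pair whose offset lies in {-2, -1, +1, +2} and stays inside the sentence is kept, so the three-word sentence "a b c" yields six pairs:

toy = process_w2v_data(["a b c"], skip_window=2, method="skip_gram")
# center "a": (a,b), (a,c)   center "b": (b,a), (b,c)   center "c": (c,a), (c,b)
print(toy.x.shape, toy.y.shape)   # (6,) (6,)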

def maybe_download_mrpc(save_dir="./MRPC/", proxy=None):
    train_url = 'https://mofanpy.com/static/files/MRPC/msr_paraphrase_train.txt'
    test_url = 'https://mofanpy.com/static/files/MRPC/msr_paraphrase_test.txt'
    os.makedirs(save_dir, exist_ok=True)
    proxies = {"http": proxy, "https": proxy}
    for url in [train_url, test_url]:
        raw_path = os.path.join(save_dir, url.split("/")[-1])
        if not os.path.isfile(raw_path):
            print("downloading from %s" % url)
            r = requests.get(url, proxies=proxies)
            with open(raw_path, "w", encoding="utf-8") as f:
                f.write(r.text.replace('"', "<QUOTE>"))
                print("completed")


def _text_standardize(text):
    text = re.sub(r'—', '-', text)
    text = re.sub(r'–', '-', text)
    text = re.sub(r'―', '-', text)
    text = re.sub(r" \d+(,\d+)?(\.\d+)? ", " <NUM> ", text)
    text = re.sub(r" \d+-+?\d*", " <NUM>-", text)
    return text.strip()
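A quick illustration of the standardiser (my own example): the dash variants collapse to "-", and free-standing numbers become the <NUM> token:

print(_text_standardize("revenue rose 1,200 – about 3.5 percent"))
# revenue rose <NUM> - about <NUM> percent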


def _process_mrpc(dir="./MRPC", rows=None):
    data = {"train": None, "test": None}
    files = os.listdir(dir)
    for f in files:
        df = pd.read_csv(os.path.join(dir, f), sep='\t', nrows=rows)
        k = "train" if "train" in f else "test"
        data[k] = {"is_same": df.iloc[:, 0].values, "s1": df["#1 String"].values, "s2": df["#2 String"].values}
    vocab = set()
    for n in ["train", "test"]:
        for m in ["s1", "s2"]:
            for i in range(len(data[n][m])):
                data[n][m][i] = _text_standardize(data[n][m][i].lower())
                cs = data[n][m][i].split(" ")
                vocab.update(set(cs))
    v2i = {v: i for i, v in enumerate(sorted(vocab), start=1)}
    v2i["<PAD>"] = PAD_ID
    v2i["<MASK>"] = len(v2i)
    v2i["<SEP>"] = len(v2i)
    v2i["<GO>"] = len(v2i)
    i2v = {i: v for v, i in v2i.items()}
    for n in ["train", "test"]:
        for m in ["s1", "s2"]:
            data[n][m+"id"] = [[v2i[v] for v in c.split(" ")] for c in data[n][m]]
    return data, v2i, i2v

class MRPCData(tDataset):
    num_seg = 3
    pad_id = PAD_ID

    def __init__(self, data_dir="./MRPC/", rows=None, proxy=None):
        maybe_download_mrpc(save_dir=data_dir, proxy=proxy)
        data, self.v2i, self.i2v = _process_mrpc(data_dir, rows)
        self.max_len = max(
            [len(s1) + len(s2) + 3 for s1, s2 in zip(
                data["train"]["s1id"] + data["test"]["s1id"], data["train"]["s2id"] + data["test"]["s2id"])])

        self.xlen = np.array([
            [
                len(data["train"]["s1id"][i]), len(data["train"]["s2id"][i])
             ] for i in range(len(data["train"]["s1id"]))], dtype=int)
        x = [
            [self.v2i["<GO>"]] + data["train"]["s1id"][i] + [self.v2i["<SEP>"]] + data["train"]["s2id"][i] + [self.v2i["<SEP>"]]
            for i in range(len(self.xlen))
        ]
        self.x = pad_zero(x, max_len=self.max_len)
        self.nsp_y = data["train"]["is_same"][:, None]

        self.seg = np.full(self.x.shape, self.num_seg-1, np.int32)
        for i in range(len(x)):
            si = self.xlen[i][0] + 2
            self.seg[i, :si] = 0
            si_ = si + self.xlen[i][1] + 1
            self.seg[i, si:si_] = 1

        self.word_ids = np.array(list(set(self.i2v.keys()).difference(
            [self.v2i[v] for v in ["<PAD>", "<MASK>", "<SEP>"]])))
    
    def __getitem__(self, idx):
        return self.x[idx], self.seg[idx], self.xlen[idx], self.nsp_y[idx]

    def sample(self, n):
        bi = np.random.randint(0, self.x.shape[0], size=n)
        bx, bs, bl, by = self.x[bi], self.seg[bi], self.xlen[bi], self.nsp_y[bi]
        return bx, bs, bl, by

    @property
    def num_word(self):
        return len(self.v2i)
    
    def __len__(self):
        return len(self.x)

    @property
    def mask_id(self):
        return self.v2i["<MASK>"]

class MRPCSingle(tDataset):
    pad_id = PAD_ID

    def __init__(self, data_dir="./MRPC/", rows=None, proxy=None):
        maybe_download_mrpc(save_dir=data_dir, proxy=proxy)

        data, self.v2i, self.i2v = _process_mrpc(data_dir, rows)

        self.max_len = max([len(s) + 2 for s in data["train"]["s1id"] + data["train"]["s2id"]])
        x = [
            [self.v2i["<GO>"]] + data["train"]["s1id"][i] + [self.v2i["<SEP>"]]
            for i in range(len(data["train"]["s1id"]))
        ]
        x += [
            [self.v2i["<GO>"]] + data["train"]["s2id"][i] + [self.v2i["<SEP>"]]
            for i in range(len(data["train"]["s2id"]))
        ]
        self.x = pad_zero(x, max_len=self.max_len)
        self.word_ids = np.array(list(set(self.i2v.keys()).difference([self.v2i["<PAD>"]])))

    def sample(self, n):
        bi = np.random.randint(0, self.x.shape[0], size=n)
        bx = self.x[bi]
        return bx

    @property
    def num_word(self):
        return len(self.v2i)
    
    def __getitem__(self, index):
        return self.x[index]

    def __len__(self):
        return len(self.x)
