NLP之TF-IDF学习

'''
Description: nlp之TF-IDF学习
Autor: 365JHWZGo
Date: 2021-11-16 14:48:45
LastEditors: 365JHWZGo
LastEditTime: 2021-11-16 18:27:02
'''
import numpy as np
from collections import Counter
import itertools
from visual import show_tfidf

docs = [
    "Born on Oct. 21st, 1992, in Shijiazhuang, Hebei province, Allen Deng is one of the most promising actors of the new generation. ",
    "Brought up by his grandparents, Deng was greatly influenced by his grandpa, who used to be a professor in college.",
    "It was his beloved grandpa who shaped his character and made him what he is today. Good family education nurtured him into a person who knows fairly well how to respond to various situations.",
    "In the second year of his college, as a sophomore, he was chosen by Qiong Yao, a writer in Taiwan who is famous for her romantic novels, to play the role Xu Hao in a TV series Flowers in Fog, marking the beginning of his acting career.",
    "Having graduated from ShanghaiTheatreAcademy, Deng went to Beijing with great passion for acting. He played in some TV series but didn’t get known by many.",
    "it is a good day, I like to stay here",
    "I am happy to be here",
    "I am bob",
    "it is sunny today",
    "I have a party today",
    "it is a dog and that is a cat",
    "there are dog and cat on the tree",
    "I study hard this morning",
    "today is a good day"
]

# Convert every document into a word list and give each distinct word an integer
# id, so the tf / idf statistics below can be stored in arrays indexed by id.
# split() with no argument drops empty tokens (e.g. from the trailing space at the
# end of docs[0]) instead of turning them into a vocabulary "word".
docs_words = [d.replace(",", "").split() for d in docs]
vocab = set(itertools.chain(*docs_words))
# Number the words in sorted order: set iteration order depends on per-process
# string hash randomisation, so numbering the raw set would give different ids
# (and a different heat-map column order) on every run.
v2i = {v: i for i, v in enumerate(sorted(vocab))}
i2v = {i: v for v, i in v2i.items()}

def safe_log(x):
    """Return the element-wise natural log of x, leaving zero entries as 0.

    The computation is done on a float64 copy, so the caller's array is never
    modified. Callers such as the "log_avg" tf weighting rely on this:
    they use x again (np.mean(x)) after calling safe_log(x).
    """
    out = np.array(x, dtype=np.float64)  # np.array copies by default
    mask = out != 0
    out[mask] = np.log(out[mask])
    return out

# Named idf transforms. Each one receives the document-frequency column built
# by get_idf (one row per vocabulary word) and returns idf weights.
idf_methods = dict(
    # Smoothed log idf: 1 + log(N / (df + 1)).
    log=lambda df: 1 + np.log(len(docs) / (df + 1)),
    # Probabilistic idf log((N - df) / (df + 1)), clipped at 0.
    prob=lambda df: np.maximum(0, np.log((len(docs) - df) / (df + 1))),
    # Document frequency divided by (sum of squared frequencies + 1).
    len_norm=lambda df: df / (np.sum(np.square(df)) + 1),
)

def get_idf(method="log"):
    """Return the inverse document frequency of every vocabulary word.

    Args:
        method: key of ``idf_methods`` ("log", "prob" or "len_norm").

    Returns:
        Array of shape (len(i2v), 1): the chosen idf transform applied to each
        word's document frequency (the number of documents that contain it).

    Raises:
        ValueError: if ``method`` is not a key of ``idf_methods``.
    """
    # Validate before doing any work.
    idf_fn = idf_methods.get(method, None)
    if idf_fn is None:
        raise ValueError("unknown idf method {!r}; choose from {}".format(method, list(idf_methods)))
    # Set membership is O(1); testing against each word list was O(len(doc)) per lookup.
    doc_word_sets = [set(d) for d in docs_words]
    df = np.zeros((len(i2v), 1))
    for i in range(len(i2v)):
        word = i2v[i]
        df[i, 0] = sum(word in words for words in doc_word_sets)
    return idf_fn(df)

# Named tf weightings. Each one receives the raw (vocab_size, n_docs) term
# frequency matrix built by get_tf (documents are columns) and returns the
# weighted term frequency of the same shape.
tf_methods = {
    "log": lambda x: np.log(1 + x),
    # Augmented frequency: 0.5 + 0.5 * tf / (largest tf in the same document).
    # Documents are columns, so the per-document maximum is taken over axis 0.
    "augmented": lambda x: 0.5 + 0.5 * x / np.max(x, axis=0, keepdims=True),
    "boolean": lambda x: np.minimum(x, 1),
    # NOTE(review): this averages over axis=1 (each word across documents) and
    # includes zero entries, whereas the textbook log-average tf divides by the
    # mean tf of the terms present in each document. Left unchanged pending
    # confirmation of the intended definition.
    "log_avg": lambda x: (1 + safe_log(x)) / (1 + safe_log(np.mean(x, axis=1, keepdims=True))),
}

def get_tf(method="log"):
    """Return the weighted term-frequency matrix of the corpus.

    Raw tf is each word's count in a document divided by the count of that
    document's most frequent word. The result is then passed through
    ``tf_methods[method]``.

    Args:
        method: key of ``tf_methods`` ("log", "augmented", "boolean" or "log_avg").

    Returns:
        float64 array of shape (len(vocab), len(docs)); column i is document i.

    Raises:
        ValueError: if ``method`` is not a key of ``tf_methods``.
    """
    weighted_tf = tf_methods.get(method, None)
    if weighted_tf is None:
        raise ValueError("unknown tf method {!r}; choose from {}".format(method, list(tf_methods)))
    _tf = np.zeros((len(vocab), len(docs)), dtype=np.float64)
    for i, d in enumerate(docs_words):
        counter = Counter(d)
        if not counter:
            continue  # empty document: leave its column at zero
        # The largest count in this document is loop-invariant; compute it once.
        max_count = counter.most_common(1)[0][1]
        for v, count in counter.items():
            _tf[v2i[v], i] = count / max_count
    return weighted_tf(_tf)

def cosine_similarity(q, _tf_idf):
    """Cosine similarity between a query column and every document column.

    Args:
        q: array of shape (V, 1), the query vector.
        _tf_idf: array of shape (V, D), one document per column.

    Returns:
        1-D array of length D. A zero query or a zero document column has no
        direction, so its similarity is reported as 0 instead of NaN (0/0).
    """
    q_norm = np.sqrt(np.sum(np.square(q), axis=0, keepdims=True))
    ds_norm = np.sqrt(np.sum(np.square(_tf_idf), axis=0, keepdims=True))
    # Divide zero vectors by 1: they stay zero and produce similarity 0.
    unit_q = q / np.where(q_norm == 0, 1, q_norm)
    unit_ds = _tf_idf / np.where(ds_norm == 0, 1, ds_norm)
    similarity = unit_ds.T.dot(unit_q).ravel()
    return similarity

def docs_score(q, len_norm=False):
    """Score every document against the query string ``q`` by tf-idf cosine similarity.

    Reads the module-level ``idf`` and ``tf_idf`` arrays, which the ``__main__``
    section computes. Query words that are not in the vocabulary are ignored:
    they have no idf and appear in no document, so they cannot change the score.
    The global ``v2i``/``i2v`` are not modified, which keeps repeated queries
    consistent with the fixed-size ``idf`` array.

    Args:
        q: query text; commas are removed and it is split on whitespace.
        len_norm: if True, divide each score by the document's word count.

    Returns:
        1-D float array with one score per document.
    """
    q_words = q.replace(",", "").split()
    counter = Counter(q_words)
    q_tf = np.zeros((len(idf), 1), dtype=np.float64)  # np.float was removed from NumPy
    for v, count in counter.items():
        idx = v2i.get(v)
        if idx is not None:
            q_tf[idx, 0] = count
    q_vec = q_tf * idf
    q_scores = cosine_similarity(q_vec, tf_idf)
    if len_norm:
        len_docs = [len(d) for d in docs_words]
        q_scores = q_scores / np.array(len_docs)
    return q_scores

def get_keywords(n=2, n_docs=3):
    """Print the ``n`` highest-weighted tf-idf words of each of the first ``n_docs`` documents.

    Reads the module-level ``tf_idf`` matrix (words x documents) and ``i2v``.
    Keywords are listed from most to least important. ``n_docs`` is capped at
    the number of documents available.
    """
    n_docs = min(n_docs, tf_idf.shape[1])
    for c in range(n_docs):
        col = tf_idf[:, c]
        # Reverse the ascending argsort, then take the first n; n=0 yields an empty
        # list (slicing with [-0:] would have returned the whole vocabulary).
        idx = np.argsort(col)[::-1][:n]
        print("doc{},top{} keywords {}".format(c, n, [i2v[i] for i in idx]))


if __name__ == '__main__':
    # tf, idf and tf_idf must remain module-level names: get_keywords() and
    # docs_score() read idf / tf_idf as globals.
    tf = get_tf()
    idf = get_idf()
    tf_idf = tf * idf
    for label, value in (
        ("tf shape(vecb in each docs): ", tf.shape),
        ("\ntf samples:\n", tf[:2]),
        ("\nidf shape(vecb in all docs): ", idf.shape),
        ("\nidf samples:\n", idf[:2]),
        ("\ntf_idf shape: ", tf_idf.shape),
        ("\ntf_idf sample:\n", tf_idf[:2]),
    ):
        print(label, value)

    get_keywords()

    q = 'who is Allen Deng'
    scores = docs_score(q)
    # Indices of the three best-scoring documents, best first.
    best = np.argsort(scores)[::-1][:3]
    print('\ntop 3 docs for "{}":\n{}'.format(q, [docs[i] for i in best]))

    words = [i2v[i] for i in range(tf_idf.shape[0])]
    show_tfidf(tf_idf.T, words, "tfidf_matrix")
    
    

nlp目录结构
[图片:nlp 目录结构]
需要额外安装的包:

conda install requests
conda install pandas

单词分布结果:
[图片:单词分布结果]
代码运行结果:
[图片:代码运行结果]

utils.py

'''
Description: utils.py 依赖
Autor: 365JHWZGo
Date: 2021-11-16 16:56:48
LastEditors: 365JHWZGo
LastEditTime: 2021-11-16 17:08:37
'''
import numpy as np
import datetime
import os
import requests
import pandas as pd
import re
import itertools

# Id reserved for padding: pad_zero() fills with it, and the vocabularies built
# in this module map "<PAD>" to it.
PAD_ID = 0


class DateData:
    """Synthetic date-format translation data.

    Each sample pairs a "yy-mm-dd" source string (``date_cn``) with a
    "dd/Mon/YYYY" target string (``date_en``). ``x`` holds the source token ids
    (one per character). ``y`` holds the target ids wrapped in <GO> ... <EOS>;
    in ``y`` the month abbreviation is a single token and every other
    character is its own token. Id 0 is <PAD>; the remaining tokens are
    numbered from 1 in sorted order.

    NOTE(review): the global NumPy RNG is seeded with 1, fromtimestamp() uses the
    machine's local time zone, and %b is locale-dependent (en[3:6] assumes a
    3-letter month abbreviation).
    """

    def __init__(self, n):
        np.random.seed(1)
        stamps = np.random.randint(143835585, 2043835585, n)
        moments = [datetime.datetime.fromtimestamp(ts) for ts in stamps]
        self.date_cn = [m.strftime("%y-%m-%d") for m in moments]
        self.date_en = [m.strftime("%d/%b/%Y") for m in moments]

        digits = [str(d) for d in range(0, 10)]
        months = [en.split("/")[1] for en in self.date_en]
        self.vocab = set(digits + ["-", "/", "<GO>", "<EOS>"] + months)
        self.v2i = {tok: idx for idx, tok in enumerate(sorted(list(self.vocab)), start=1)}
        self.v2i["<PAD>"] = PAD_ID
        self.vocab.add("<PAD>")
        self.i2v = {idx: tok for tok, idx in self.v2i.items()}

        go_id = self.v2i["<GO>"]
        eos_id = self.v2i["<EOS>"]
        sources, targets = [], []
        for cn, en in zip(self.date_cn, self.date_en):
            sources.append([self.v2i[ch] for ch in cn])
            day_part = [self.v2i[ch] for ch in en[:3]]   # "dd/"
            month_id = self.v2i[en[3:6]]                 # "Mon" as one token
            year_part = [self.v2i[ch] for ch in en[6:]]  # "/YYYY"
            targets.append([go_id] + day_part + [month_id] + year_part + [eos_id])
        self.x = np.array(sources)
        self.y = np.array(targets)
        self.start_token = go_id
        self.end_token = eos_id

    def sample(self, n=64):
        """Draw n random samples (with replacement).

        Returns (source ids, target ids, decoder lengths). Every decoder length
        is the target width minus one.
        """
        picks = np.random.randint(0, len(self.x), size=n)
        bx = self.x[picks]
        by = self.y[picks]
        steps = by.shape[1] - 1
        return bx, by, np.full((len(bx),), steps, dtype=np.int32)

    def idx2str(self, idx):
        """Decode ids to a string, stopping after (and including) the end token."""
        pieces = []
        for token in idx:
            pieces.append(self.i2v[token])
            if token == self.end_token:
                break
        return "".join(pieces)

    @property
    def num_word(self):
        """Vocabulary size, including <PAD>."""
        return len(self.vocab)


def pad_zero(seqs, max_len):
    """Right-pad id sequences with PAD_ID into a (len(seqs), max_len) int64 array.

    A sequence longer than max_len makes the slice assignment below raise
    ValueError; it is not silently truncated.
    """
    # np.long is a deprecated alias that is not available in every NumPy release;
    # np.int64 is explicit and gives 64-bit ids on every platform.
    padded = np.full((len(seqs), max_len), fill_value=PAD_ID, dtype=np.int64)
    for i, seq in enumerate(seqs):
        padded[i, :len(seq)] = seq
    return padded


def maybe_download_mrpc(save_dir="./MRPC/", proxy=None):
    """Download the MRPC train/test files into save_dir unless they already exist.

    Double quotes in the text are replaced by "<QUOTE>" before saving.

    Raises:
        requests.HTTPError: if a download returns an HTTP error status. Nothing
            is written in that case, so the next call retries the download.
    """
    train_url = 'https://mofanpy.com/static/files/MRPC/msr_paraphrase_train.txt'
    test_url = 'https://mofanpy.com/static/files/MRPC/msr_paraphrase_test.txt'
    os.makedirs(save_dir, exist_ok=True)
    proxies = {"http": proxy, "https": proxy}
    for url in [train_url, test_url]:
        raw_path = os.path.join(save_dir, url.split("/")[-1])
        if os.path.isfile(raw_path):
            continue
        print("downloading from %s" % url)
        r = requests.get(url, proxies=proxies)
        # Without this check an error page would be cached as the dataset and,
        # because the file now exists, never re-downloaded.
        r.raise_for_status()
        with open(raw_path, "w", encoding="utf-8") as f:
            f.write(r.text.replace('"', "<QUOTE>"))
        print("completed")


def _text_standardize(text):
    text = re.sub(r'—', '-', text)
    text = re.sub(r'–', '-', text)
    text = re.sub(r'―', '-', text)
    text = re.sub(r" \d+(,\d+)?(\.\d+)? ", " <NUM> ", text)
    text = re.sub(r" \d+-+?\d*", " <NUM>-", text)
    return text.strip()


def _process_mrpc(dir="./MRPC", rows=None):
    """Load, normalise and index the MRPC train/test splits found in ``dir``.

    Only ``.txt`` files whose name contains "train" or "test" are read. Other
    files in the directory (e.g. hidden OS files) are ignored instead of
    overwriting the test split.

    Args:
        dir: directory holding the downloaded MRPC files.
        rows: optional row limit passed to ``pandas.read_csv(nrows=...)``.

    Returns:
        (data, v2i, i2v). ``data[split]`` has "is_same", "s1", "s2" (lower-cased,
        standardised text) and "s1id"/"s2id" (word-id lists). The vocabulary
        numbers the sorted words from 1 and then appends <MASK>, <SEP> and <GO>;
        <PAD> is PAD_ID.

    Raises:
        FileNotFoundError: if the train or test file is missing.
    """
    data = {"train": None, "test": None}
    for f in sorted(os.listdir(dir)):
        path = os.path.join(dir, f)
        if not (os.path.isfile(path) and f.endswith(".txt")):
            continue
        if "train" in f:
            k = "train"
        elif "test" in f:
            k = "test"
        else:
            continue
        df = pd.read_csv(path, sep='\t', nrows=rows)
        data[k] = {"is_same": df.iloc[:, 0].values, "s1": df["#1 String"].values, "s2": df["#2 String"].values}
    missing = [k for k, v in data.items() if v is None]
    if missing:
        raise FileNotFoundError("MRPC split(s) %s not found in %r" % (missing, dir))
    vocab = set()
    for n in ["train", "test"]:
        for m in ["s1", "s2"]:
            for i in range(len(data[n][m])):
                data[n][m][i] = _text_standardize(data[n][m][i].lower())
                cs = data[n][m][i].split(" ")
                vocab.update(set(cs))
    v2i = {v: i for i, v in enumerate(sorted(vocab), start=1)}
    v2i["<PAD>"] = PAD_ID
    v2i["<MASK>"] = len(v2i)
    v2i["<SEP>"] = len(v2i)
    v2i["<GO>"] = len(v2i)
    i2v = {i: v for v, i in v2i.items()}
    for n in ["train", "test"]:
        for m in ["s1", "s2"]:
            data[n][m+"id"] = [[v2i[v] for v in c.split(" ")] for c in data[n][m]]
    return data, v2i, i2v


class MRPCData:
    """Sentence-pair samples built from the MRPC training split.

    Each sample is ``<GO> s1 <SEP> s2 <SEP>`` right-padded with PAD_ID to
    ``max_len``, the longest such pair over the train *and* test splits. The
    instance also keeps per-token segment ids (``seg``), the two sentence
    lengths (``xlen``) and the ``is_same`` label column (``nsp_y``).
    """
    # Segment ids written to self.seg: 0 = "<GO> s1 <SEP>", 1 = "s2 <SEP>", 2 = padding.
    num_seg = 3
    pad_id = PAD_ID

    def __init__(self, data_dir="./MRPC/", rows=None, proxy=None):
        # Download the raw files if needed, then normalise them and build the vocabulary.
        maybe_download_mrpc(save_dir=data_dir, proxy=proxy)
        data, self.v2i, self.i2v = _process_mrpc(data_dir, rows)
        # +3 accounts for the <GO> and the two <SEP> tokens added to every pair.
        self.max_len = max(
            [len(s1) + len(s2) + 3 for s1, s2 in zip(
                data["train"]["s1id"] + data["test"]["s1id"], data["train"]["s2id"] + data["test"]["s2id"])])

        # One row per training pair: [len(s1), len(s2)], without special tokens.
        self.xlen = np.array([
            [
                len(data["train"]["s1id"][i]), len(data["train"]["s2id"][i])
             ] for i in range(len(data["train"]["s1id"]))], dtype=int)
        x = [
            [self.v2i["<GO>"]] + data["train"]["s1id"][i] + [self.v2i["<SEP>"]] + data["train"]["s2id"][i] + [self.v2i["<SEP>"]]
            for i in range(len(self.xlen))
        ]
        self.x = pad_zero(x, max_len=self.max_len)
        # Labels as a column vector, one row per training pair.
        self.nsp_y = data["train"]["is_same"][:, None]

        # Start with every position in the padding segment, then mark the two sentences.
        self.seg = np.full(self.x.shape, self.num_seg-1, np.int32)
        for i in range(len(x)):
            si = self.xlen[i][0] + 2  # end of "<GO> s1 <SEP>"
            self.seg[i, :si] = 0
            si_ = si + self.xlen[i][1] + 1  # end of "s2 <SEP>"
            self.seg[i, si:si_] = 1

        # Every vocabulary id except <PAD>, <MASK> and <SEP> (note: <GO> is kept).
        self.word_ids = np.array(list(set(self.i2v.keys()).difference(
            [self.v2i[v] for v in ["<PAD>", "<MASK>", "<SEP>"]])))

    def sample(self, n):
        """Draw n random training pairs (with replacement).

        Returns (token ids, segment ids, [len(s1), len(s2)] rows, is_same labels).
        """
        bi = np.random.randint(0, self.x.shape[0], size=n)
        bx, bs, bl, by = self.x[bi], self.seg[bi], self.xlen[bi], self.nsp_y[bi]
        return bx, bs, bl, by

    @property
    def num_word(self):
        """Vocabulary size, including the special tokens."""
        return len(self.v2i)

    @property
    def mask_id(self):
        """Id of the <MASK> token."""
        return self.v2i["<MASK>"]


class MRPCSingle:
    """Single-sentence samples from the MRPC training split.

    Every s1 sentence and then every s2 sentence of the training split becomes
    its own sample ``<GO> sentence <SEP>``, right-padded with PAD_ID.
    """
    pad_id = PAD_ID

    def __init__(self, data_dir="./MRPC/", rows=None, proxy=None):
        maybe_download_mrpc(save_dir=data_dir, proxy=proxy)
        data, self.v2i, self.i2v = _process_mrpc(data_dir, rows)

        sentences = data["train"]["s1id"] + data["train"]["s2id"]
        # +2 leaves room for the <GO> and <SEP> wrapped around each sentence.
        self.max_len = max([len(s) + 2 for s in sentences])
        go_id = self.v2i["<GO>"]
        sep_id = self.v2i["<SEP>"]
        wrapped = [[go_id] + s + [sep_id] for s in sentences]
        self.x = pad_zero(wrapped, max_len=self.max_len)
        # Every vocabulary id except <PAD>.
        self.word_ids = np.array(list(set(self.i2v.keys()).difference([self.v2i["<PAD>"]])))

    def sample(self, n):
        """Return n randomly chosen padded sentences (with replacement)."""
        picks = np.random.randint(0, self.x.shape[0], size=n)
        return self.x[picks]

    @property
    def num_word(self):
        """Vocabulary size, including the special tokens."""
        return len(self.v2i)


class Dataset:
    """In-memory (x, y) training pairs plus their vocabulary, with random batching."""

    def __init__(self, x, y, v2i, i2v):
        self.x = x
        self.y = y
        self.v2i = v2i
        self.i2v = i2v
        self.vocab = v2i.keys()

    def sample(self, n):
        """Return n (x, y) rows drawn uniformly at random, with replacement."""
        rows = np.random.randint(0, len(self.x), n)
        return self.x[rows], self.y[rows]

    @property
    def num_word(self):
        """Number of entries in the vocabulary."""
        return len(self.v2i)


def process_w2v_data(corpus, skip_window=2, method="skip_gram"):
    """Build a word2vec training Dataset from space-separated sentences.

    Args:
        corpus: sequence of sentences; words are separated by single spaces.
            It is iterated twice, so pass a list rather than a generator.
        skip_window: number of context words taken on each side of a center word.
        method: "skip_gram" gives one (center, context) pair per in-range
            neighbour. "cbow" gives one (contexts..., center) row per position
            with a full window. Matched case-insensitively.

    Returns:
        Dataset whose vocabulary ids are ordered from most to least frequent word.

    Raises:
        ValueError: for an unknown ``method``, or when no training pairs can be built.
    """
    # Normalise once so skip_gram and cbow are both matched case-insensitively,
    # and reject bad input before doing any work.
    method_name = method.lower()
    if method_name not in ("skip_gram", "cbow"):
        raise ValueError("unknown method %r; expected 'skip_gram' or 'cbow'" % method)

    all_words = [sentence.split(" ") for sentence in corpus]
    all_words = np.array(list(itertools.chain(*all_words)))
    # vocab sort by decreasing frequency for the negative sampling below (nce_loss).
    vocab, v_count = np.unique(all_words, return_counts=True)
    vocab = vocab[np.argsort(v_count)[::-1]]

    print("all vocabularies sorted from more frequent to less frequent:\n", vocab)
    v2i = {v: i for i, v in enumerate(vocab)}
    i2v = {i: v for v, i in v2i.items()}

    # pair data
    pairs = []
    js = [i for i in range(-skip_window, skip_window + 1) if i != 0]

    for c in corpus:
        w_idx = [v2i[w] for w in c.split(" ")]
        if method_name == "skip_gram":
            for i in range(len(w_idx)):
                for j in js:
                    if 0 <= i + j < len(w_idx):
                        pairs.append((w_idx[i], w_idx[i + j]))  # (center, context) or (feature, target)
        else:
            for i in range(skip_window, len(w_idx) - skip_window):
                context = [w_idx[i + j] for j in js]
                pairs.append(context + [w_idx[i]])  # (contexts, center) or (feature, target)
    if not pairs:
        # np.array([]) is 1-D, so the column slicing below would fail obscurely.
        raise ValueError("corpus produced no training pairs (sentences shorter than the window?)")
    pairs = np.array(pairs)
    print("5 example pairs:\n", pairs[:5])
    if method_name == "skip_gram":
        x, y = pairs[:, 0], pairs[:, 1]
    else:
        x, y = pairs[:, :-1], pairs[:, -1]
    return Dataset(x, y, v2i, i2v)


def set_soft_gpu(soft_gpu):
    """When soft_gpu is truthy, enable TensorFlow memory growth on every visible GPU.

    TensorFlow is imported in every case; the GPU settings are applied only when
    soft_gpu is truthy and at least one GPU is found.
    """
    import tensorflow as tf
    if not soft_gpu:
        return
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if not gpus:
        return
    # Currently, memory growth needs to be the same across GPUs
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
    logical = tf.config.experimental.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical), "Logical GPUs")

visual.py

'''
Description: visual.py依赖
Autor: 365JHWZGo
Date: 2021-11-16 16:58:26
LastEditors: 365JHWZGo
LastEditTime: 2021-11-16 17:15:05
'''
import matplotlib.pyplot as plt
import numpy as np
import pickle
from matplotlib.pyplot import cm
import os
import utils


def show_tfidf(tfidf, vocab, filename):
    """Plot a tf-idf matrix as a heat map, save it as a PNG and show it.

    Args:
        tfidf: array of shape (n_doc, n_vocab); rows are documents.
        vocab: n_vocab labels for the x axis.
        filename: base name of the image written to ``./nlp/images/``.
    """
    # [n_doc, n_vocab]
    plt.imshow(tfidf, cmap="YlGn", vmin=tfidf.min(), vmax=tfidf.max())
    plt.xticks(np.arange(tfidf.shape[1]), vocab, fontsize=6, rotation=90)
    # Documents are labelled from 1.
    plt.yticks(np.arange(tfidf.shape[0]), np.arange(1, tfidf.shape[0]+1), fontsize=6)
    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs("./nlp/images", exist_ok=True)
    plt.savefig("./nlp/images/%s.png" % filename, format="png", dpi=500)
    plt.show()


def show_w2v_word_embedding(model, data: utils.Dataset, path):
    """Draw every vocabulary word at its first two embedding coordinates.

    Words that parse as integers are drawn in blue, all others in red. The
    figure is saved to ``path`` as a PNG and then shown.
    NOTE(review): assumes ``model.embeddings.get_weights()[0]`` is a
    (num_word, >=2) matrix — confirm against the model definition.
    """
    emb = model.embeddings.get_weights()[0]
    for idx in range(data.num_word):
        word = data.i2v[idx]
        try:
            int(word)
        except ValueError:
            color = "red"
        else:
            color = "blue"
        plt.text(emb[idx, 0], emb[idx, 1], s=word, color=color, weight="bold")
    xs, ys = emb[:, 0], emb[:, 1]
    plt.xlim(xs.min() - .5, xs.max() + .5)
    plt.ylim(ys.min() - .5, ys.max() + .5)
    plt.xticks(())
    plt.yticks(())
    plt.xlabel("embedding dim1")
    plt.ylabel("embedding dim2")
    plt.savefig(path, dpi=300, format="png")
    plt.show()


def seq2seq_attention():
    """Plot seq2seq attention alignments of the first six samples as a 2x3 grid of heat maps.

    Reads ``./visual/tmp/attention_align.pkl`` (a dict with keys "i2v", "x", "y"
    and "align"), saves ``./visual/results/seq2seq_attention.png`` and shows it.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced locally by the training scripts.
    with open("./visual/tmp/attention_align.pkl", "rb") as f:
        data = pickle.load(f)
    i2v, x, y, align = data["i2v"], data["x"], data["y"], data["align"]
    # Draw x tick labels above the heat maps instead of below.
    plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
    plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        x_vocab = [i2v[j] for j in np.ravel(x[i])]
        # Skip the first output token (presumably <GO>) — confirm against the data writer.
        y_vocab = [i2v[j] for j in y[i, 1:]]
        plt.imshow(align[i], cmap="YlGn", vmin=0., vmax=1.)
        plt.yticks([j for j in range(len(y_vocab))], y_vocab)
        plt.xticks([j for j in range(len(x_vocab))], x_vocab)
        # Axis titles only on the left column and the bottom row.
        if i == 0 or i == 3:
            plt.ylabel("Output")
        if i >= 3:
            plt.xlabel("Input")
    plt.tight_layout()
    plt.savefig("./visual/results/seq2seq_attention.png", format="png", dpi=200)
    plt.show()


def all_mask_kinds():
    """Plot the Transformer padding mask and look-ahead mask for four toy sentences.

    Each figure is a 2x2 grid with one 6x6 heat map per sentence (sentences are
    padded to 6 steps); 1 marks a masked position. The figures are saved as
    ``./visual/results/transformer_pad_mask.png`` and
    ``./visual/results/transformer_look_ahead_mask.png`` and shown.
    """
    seqs = ["I love you", "My name is M", "This is a very long seq", "Short one"]
    vocabs = set((" ".join(seqs)).split(" "))
    i2v = {i: v for i, v in enumerate(vocabs, start=1)}
    # Reserve id 0 for <PAD> (key = id, value = word, like every other entry).
    i2v[0] = "<PAD>"
    v2i = {v: i for i, v in i2v.items()}

    id_seqs = [[v2i[v] for v in seq.split(" ")] for seq in seqs]
    padded_id_seqs = np.array([l + [0] * (6 - len(l)) for l in id_seqs])
    # One tick label per step: the sentence's words followed by <PAD>. Modern
    # matplotlib rejects a tick/label count mismatch (6 ticks vs. fewer words).
    step_labels = [s.split(" ") + ["<PAD>"] * (6 - len(s.split(" "))) for s in seqs]

    # padding mask: 1 where the id is 0 (<PAD>), else 0.
    pmask = np.where(padded_id_seqs == 0, np.ones_like(padded_id_seqs), np.zeros_like(padded_id_seqs))  # 0 idx is padding
    pmask = np.repeat(pmask[:, None, :], pmask.shape[-1], axis=1)  # [n, step, step]
    # Draw x tick labels above the heat maps instead of below.
    plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
    plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True
    for i in range(1, 5):
        plt.subplot(2, 2, i)
        plt.imshow(pmask[i-1], vmax=1, vmin=0, cmap="YlGn")
        plt.xticks(range(6), step_labels[i - 1], rotation=45)
        plt.yticks(range(6), step_labels[i - 1],)
        plt.grid(which="minor", c="w", lw=0.5, linestyle="-")
    plt.tight_layout()
    plt.savefig("./visual/results/transformer_pad_mask.png", dpi=200)
    plt.show()

    # look ahead mask: positions above the diagonal (future steps) are masked.
    max_len = pmask.shape[-1]
    # Builtin bool: the np.bool alias no longer exists in current NumPy 1.x.
    omask = ~np.triu(np.ones((max_len, max_len), dtype=bool), 1)
    omask = np.tile(np.expand_dims(omask, axis=0), [np.shape(seqs)[0], 1, 1])  # [n, step, step]
    # Past/current positions keep the padding mask; future positions become 1.
    omask = np.where(omask, pmask, 1)

    plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
    plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True
    for i in range(1, 5):
        plt.subplot(2, 2, i)
        plt.imshow(omask[i - 1], vmax=1, vmin=0, cmap="YlGn")
        plt.xticks(range(6), step_labels[i - 1], rotation=45)
        plt.yticks(range(6), step_labels[i - 1], )
        plt.grid(which="minor", c="w", lw=0.5, linestyle="-")
    plt.tight_layout()
    plt.savefig("./visual/results/transformer_look_ahead_mask.png", dpi=200)
    plt.show()


def position_embedding(max_len=500, model_dim=512):
    """Plot the sinusoidal position-encoding table of the original Transformer.

    PE[pos, 2k]   = sin(pos / 10000 ** (2k / model_dim))
    PE[pos, 2k+1] = cos(pos / 10000 ** (2k / model_dim))

    The table has shape (max_len, model_dim). It is saved as
    ``./visual/results/transformer_position_embedding.png`` and shown.
    """
    pos = np.arange(max_len)[:, None]
    # Use the pair index k = dim // 2 (not dim itself) so dims 2k and 2k+1 share
    # one frequency; using dim doubles the exponent and flattens most columns.
    pair_index = np.arange(model_dim)[None, :] // 2
    pe = pos / np.power(10000, 2. * pair_index / model_dim)  # [max_len, model_dim]
    pe[:, 0::2] = np.sin(pe[:, 0::2])
    pe[:, 1::2] = np.cos(pe[:, 1::2])
    plt.imshow(pe, vmax=1, vmin=-1, cmap="rainbow")
    plt.ylabel("word position")
    plt.xlabel("embedding dim")
    plt.savefig("./visual/results/transformer_position_embedding.png", dpi=200)
    plt.show()


def transformer_attention_matrix(case=0):
    """Plot encoder, decoder and decoder-encoder attention heat maps for sample ``case``.

    Loads ``./visual/tmp/transformer_attention_matrix.pkl`` and draws three
    figures, each a 3x4 grid (rows = layers 1-3, columns = heads 1-4):
    encoder self-attention, decoder self-attention ("mh1") and decoder-encoder
    attention ("mh2"). Each figure is saved under ``./visual/results/`` and shown.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced locally by the training scripts.
    with open("./visual/tmp/transformer_attention_matrix.pkl", "rb") as f:
        data = pickle.load(f)
    src = data["src"][case]
    tgt = data["tgt"][case]
    attentions = data["attentions"]

    # Each entry is indexed below as [layer][case, head] and then cropped to the
    # real token counts (layout inferred from that slicing — confirm with the writer).
    encoder_atten = attentions["encoder"]
    decoder_tgt_atten = attentions["decoder"]["mh1"]
    decoder_src_atten = attentions["decoder"]["mh2"]
    # Draw x tick labels above the heat maps instead of below.
    plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
    plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True

    # Figure 0: encoder self-attention (src x src).
    plt.figure(0, (7, 7))
    plt.suptitle("Encoder self-attention")
    for i in range(3):
        for j in range(4):
            plt.subplot(3, 4, i * 4 + j + 1)
            plt.imshow(encoder_atten[i][case, j][:len(src), :len(src)], vmax=1, vmin=0, cmap="rainbow")
            plt.xticks(range(len(src)), src)
            plt.yticks(range(len(src)), src)
            if j == 0:
                plt.ylabel("layer %i" % (i+1))
            if i == 2:
                plt.xlabel("head %i" % (j+1))
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)  # leave room for the suptitle
    plt.savefig("./visual/results/transformer%d_encoder_self_attention.png" % case, dpi=200)
    plt.show()

    # Figure 1: decoder self-attention (tgt x tgt).
    plt.figure(1, (7, 7))
    plt.suptitle("Decoder self-attention")
    for i in range(3):
        for j in range(4):
            plt.subplot(3, 4, i * 4 + j + 1)
            plt.imshow(decoder_tgt_atten[i][case, j][:len(tgt), :len(tgt)], vmax=1, vmin=0, cmap="rainbow")
            plt.xticks(range(len(tgt)), tgt, rotation=90, fontsize=7)
            plt.yticks(range(len(tgt)), tgt, fontsize=7)
            if j == 0:
                plt.ylabel("layer %i" % (i+1))
            if i == 2:
                plt.xlabel("head %i" % (j+1))
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.savefig("./visual/results/transformer%d_decoder_self_attention.png" % case, dpi=200)
    plt.show()

    # Figure 2: decoder-encoder attention (rows = tgt, columns = src).
    plt.figure(2, (7, 8))
    plt.suptitle("Decoder-Encoder attention")
    for i in range(3):
        for j in range(4):
            plt.subplot(3, 4, i*4+j+1)
            plt.imshow(decoder_src_atten[i][case, j][:len(tgt), :len(src)], vmax=1, vmin=0, cmap="rainbow")
            plt.xticks(range(len(src)), src, fontsize=7)
            plt.yticks(range(len(tgt)), tgt, fontsize=7)
            if j == 0:
                plt.ylabel("layer %i" % (i+1))
            if i == 2:
                plt.xlabel("head %i" % (j+1))
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.savefig("./visual/results/transformer%d_decoder_encoder_attention.png" % case, dpi=200)
    plt.show()


def transformer_attention_line(case=0):
    """Draw the last decoder layer's decoder-encoder attention as line plots for sample ``case``.

    A 2x2 grid shows heads 1-4. In each panel a line joins a source token (left
    axis) to a target token (right axis). The line's opacity is the attention
    weight divided by its row maximum, raised to the 8th power to emphasise
    strong links. Only the first 10 target and 8 source positions are drawn.
    The figure is saved as
    ``./visual/results/transformer<case>_encoder_decoder_attention_line.png``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced locally by the training scripts.
    with open("./visual/tmp/transformer_attention_matrix.pkl", "rb") as f:
        data = pickle.load(f)
    src = data["src"][case]
    tgt = data["tgt"][case]
    attentions = data["attentions"]

    decoder_src_atten = attentions["decoder"]["mh2"]

    tgt_label = tgt[1:11][::-1]
    src_label = ["" for _ in range(2)] + src[::-1]
    fig, ax = plt.subplots(nrows=2, ncols=2, sharex=True, figsize=(7, 14))

    for i in range(2):
        for j in range(2):
            # Row-major head index: the four panels show heads 1-4, matching the
            # xlabel below (indexing with i + j showed head 2 twice and never head 4).
            head = i * 2 + j
            ax[i, j].set_yticks(np.arange(len(src_label)))
            ax[i, j].set_yticklabels(src_label, fontsize=9)  # src
            ax[i, j].set_ylim(0, len(src_label)-1)
            ax_ = ax[i, j].twinx()
            # NOTE(review): this sets len(src_label) ticks but 10 target labels;
            # matplotlib raises ValueError when those counts differ — confirm inputs.
            ax_.set_yticks(np.linspace(ax_.get_yticks()[0], ax_.get_yticks()[-1], len(ax[i, j].get_yticks())))
            ax_.set_yticklabels(tgt_label, fontsize=9)      # tgt
            img = decoder_src_atten[-1][case, head][:10, :8]
            color = cm.rainbow(np.linspace(0, 1, img.shape[0]))
            left_top, right_top = img.shape[1], img.shape[0]
            for ri, c in zip(range(right_top), color):      # tgt
                for li in range(left_top):                 # src
                    alpha = (img[ri, li] / img[ri].max()) ** 8
                    ax[i, j].plot([0, 1], [left_top - li + 1, right_top - 1 - ri], alpha=alpha, c=c)
            ax[i, j].set_xticks(())
            ax[i, j].set_xlabel("head %i" % (head + 1))
            ax[i, j].set_xlim(0, 1)
    plt.subplots_adjust(top=0.9)
    plt.tight_layout()
    plt.savefig("./visual/results/transformer%d_encoder_decoder_attention_line.png" % case, dpi=100)


def self_attention_matrix(bert_or_gpt="bert", case=0):
    """Plot the last layer's self-attention for heads 1-4 of sample ``case``.

    Loads ``./visual/tmp/<bert_or_gpt>_attention_matrix.pkl`` and crops each
    head's matrix to the tokens before the first "<SEP>" (minus one). It draws
    the four heads as a 4x1 column of heat maps and saves
    ``./visual/results/<bert_or_gpt><case>_self_attention.png``. The figure is
    not shown because the ``plt.show()`` call is commented out.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced locally by the training scripts.
    with open("./visual/tmp/"+bert_or_gpt+"_attention_matrix.pkl", "rb") as f:
        data = pickle.load(f)
    src = data["src"]
    attentions = data["attentions"]

    encoder_atten = attentions["encoder"]
    # Draw x tick labels above the heat maps instead of below.
    plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False
    plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True

    # s_len = number of tokens before the first "<SEP>".
    s_len = 0
    for s in src[case]:
        if s == "<SEP>":
            break
        s_len += 1

    plt.figure(0, (7, 28))
    for j in range(4):
        plt.subplot(4, 1, j + 1)
        img = encoder_atten[-1][case, j][:s_len-1, :s_len-1]
        plt.imshow(img, vmax=img.max(), vmin=0, cmap="rainbow")
        plt.xticks(range(s_len-1), src[case][:s_len-1], rotation=90, fontsize=9)
        # Rows are labelled with the next token (src[1:s_len]) — presumably the
        # prediction target for GPT-style models; confirm.
        plt.yticks(range(s_len-1), src[case][1:s_len], fontsize=9)
        plt.xlabel("head %i" % (j+1))
    plt.subplots_adjust(top=0.9)
    plt.tight_layout()
    plt.savefig("./visual/results/"+bert_or_gpt+"%d_self_attention.png" % case, dpi=500)
    # plt.show()


def self_attention_line(bert_or_gpt="bert", case=0):
    """Draw the last layer's self-attention of heads 1-4 as line plots for sample ``case``.

    Only the tokens before the first "<SEP>" are used. In each panel of the 2x2
    grid a line joins a key token (left) to a query token (right). The line's
    opacity is the weight divided by its row maximum, raised to the 5th power.
    The token sequence is printed, and the figure is saved as
    ``./visual/results/<bert_or_gpt><case>_self_attention_line.png``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced locally by the training scripts.
    with open("./visual/tmp/"+bert_or_gpt+"_attention_matrix.pkl", "rb") as f:
        data = pickle.load(f)
    src = data["src"][case]
    attentions = data["attentions"]

    encoder_atten = attentions["encoder"]

    # s_len = number of tokens before the first "<SEP>".
    s_len = 0
    print(" ".join(src))
    for s in src:
        if s == "<SEP>":
            break
        s_len += 1
    y_label = src[:s_len][::-1]
    fig, ax = plt.subplots(nrows=2, ncols=2, sharex=True, figsize=(7, 14))

    for i in range(2):
        for j in range(2):
            # Row-major head index: the four panels show heads 1-4, matching the
            # xlabel below (indexing with i + j showed head 2 twice and never head 4).
            head = i * 2 + j
            ax[i, j].set_yticks(np.arange(len(y_label)))
            ax[i, j].tick_params(labelright=True)
            ax[i, j].set_yticklabels(y_label, fontsize=9)     # input

            img = encoder_atten[-1][case, head][:s_len - 1, :s_len - 1]
            color = cm.rainbow(np.linspace(0, 1, img.shape[0]))
            for row, c in zip(range(img.shape[0]), color):
                for col in range(img.shape[1]):
                    alpha = (img[row, col] / img[row].max()) ** 5
                    ax[i, j].plot([0, 1], [img.shape[1]-col, img.shape[0]-row-1], alpha=alpha, c=c)
            ax[i, j].set_xticks(())
            ax[i, j].set_xlabel("head %i" % (head + 1))
            ax[i, j].set_xlim(0, 1)
    plt.subplots_adjust(top=0.9)
    plt.tight_layout()
    plt.savefig("./visual/results/"+bert_or_gpt+"%d_self_attention_line.png" % case, dpi=100)


if __name__ == "__main__":
    os.makedirs("./visual/results", exist_ok=True)
    # Other demos available in this module:
    #   all_mask_kinds(), seq2seq_attention(), position_embedding()
    #   model = ["gpt", "bert", "bert_window_mask"][1]; case = 6
    #   self_attention_matrix(model, case=case); self_attention_line(model, case=case)
    for demo in (transformer_attention_matrix, transformer_attention_line):
        demo(case=0)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

365JHWZGo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值