Implementing the GloVe Algorithm with Mittens - How to Use Fixed Initialization

To use fixed initialization you need to prepare the W, C, bw and bc parameters. Because the original code forbids the fixed_initialization argument when running on the TensorFlow backend, we write a small subclass that wraps GloVe:

import numpy as np
from mittens import GloVe
from mittens.mittens_base import GloVeBase
try:
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()
except ImportError:
    import tensorflow as tf

class myGlove(GloVe):
    def __init__(self, n=100, xmax=100, alpha=0.75, max_iter=100,
                 learning_rate=0.05, tol=1e-4, display_progress=10,
                 log_dir=None, log_subdir=None, test_mode=False, init=None):
        self.init = init
        # Skip GloVeBase.__init__ in the MRO and call the base initializer
        # directly, forcing mittens=0 (pure GloVe, no Mittens retrofitting).
        super(GloVeBase, self).__init__(n=n,
                                        xmax=xmax,
                                        alpha=alpha,
                                        mittens=0,
                                        max_iter=max_iter,
                                        learning_rate=learning_rate,
                                        tol=tol,
                                        display_progress=display_progress,
                                        log_dir=log_dir,
                                        log_subdir=log_subdir,
                                        test_mode=test_mode)

    def _weight_init(self, m, n, name):
        if self.init:
            # Load the fixed initial values for this parameter from file.
            data = tf.cast(np.loadtxt(self.init[name]).reshape(m, n),
                           dtype=tf.float32)
            with tf.name_scope(name):
                return tf.Variable(data, name=name)
        else:
            # Default behavior: Xavier-style uniform initialization.
            x = np.sqrt(6.0 / (m + n))
            with tf.name_scope(name):
                return tf.Variable(
                    tf.random_uniform(
                        [m, n], minval=-x, maxval=x), name=name)

The init parameter is a dictionary of the following form:

init = {
    "W":"text\\w.txt",
    "C":"text\\c.txt",
    "bw":"text\\bw.txt",
    "bc":"text\\bc.txt"
}

The expected shapes of W, C, bw and bc are:

W = (n_words, n)
C = (n_words, n)
bw = (n_words, 1)
bc = (n_words, 1)

Here n_words is the vocabulary size and n is the embedding dimension n set in GloVe.
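
A minimal sketch of how such files could be produced with np.savetxt, which writes them in the whitespace-separated format that np.loadtxt in _weight_init reads back. The random values and the vocabulary size below are placeholders; in practice the four arrays would come from a previous training run:

import numpy as np

n_words, n = 4860, 10  # placeholder sizes; must match the corpus and the model
np.savetxt("text\\w.txt", np.random.uniform(-0.5, 0.5, (n_words, n)))
np.savetxt("text\\c.txt", np.random.uniform(-0.5, 0.5, (n_words, n)))
np.savetxt("text\\bw.txt", np.random.uniform(-0.5, 0.5, (n_words, 1)))
np.savetxt("text\\bc.txt", np.random.uniform(-0.5, 0.5, (n_words, 1)))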

A complete usage example:

# coding:utf8
"""This file contains common utilities for Chinese NLP preprocessing."""

import re
import string
import numpy as np
import jieba
from gensim import corpora

class Preprocess(object):
    """Chinese NLP preprocessing helpers."""

    # Regular expressions used to clean the data
    DIGIT_RE = re.compile(r'\d+')
    LETTER_RE = re.compile(r'[a-zA-Z]+')
    SPECIAL_SYMBOL_RE = re.compile(r'[^\u4e00-\u9fa5]+')  # used to remove non-Chinese characters and special symbols
    STOPS = ['。', '.', '?', '?', '!', '!']  # Chinese and English sentence-final characters

    # Minimum and maximum allowed sentence lengths
    SENTENCE_MIN_LEN = 5
    SENTENCE_MAX_LEN = 50

    def __init__(self, stopwords=None):
        self.stopwords = stopwords

    @staticmethod
    def read_text_file(text_file):
        """读取文本文件,并返回由每行文本作为元素组成的list."""
        with open(text_file, 'r', encoding='utf-8') as file:
            lines = [line.strip() for line in file]
        return lines

    @staticmethod
    def write_text_file(text_list, target_file):
        """将文本列表写入目标文件

        Args:
            text_list: 列表,每个元素是一条文本
            target_file: 字符串,写入目标文件路径
        """
        with open(target_file, 'w', encoding='utf-8') as writer:
            for text in text_list:
                writer.write(text + '\n')

    @staticmethod
    def del_blank_lines(sentences):
        """删除句子列表中的空行,返回没有空行的句子列表

        Args:
            sentences: 字符串列表
        """
        data = []
        for s in sentences:
            s=s.strip()
            if s:
                data.append(s)
        return data

    @staticmethod
    def del_punctuation(sentence):
        """删除字符串中的中英文标点.

        Args:
            sentence: 字符串
        """
        en_punc_tab = str.maketrans('', '', string.punctuation)  # ↓ ① ℃处理不了
        sent_no_en_punc = sentence.translate(en_punc_tab)
        return re.sub(r'[%s]+' % string.punctuation, " ", sent_no_en_punc)

    @staticmethod
    def del_stopwords(sentences, stopwords):
        """Remove stopwords from segmented sentences.

        Args:
            sentences: list of segmented sentences (words separated by spaces)
            stopwords: string, path of the stopword file

        Returns: list of sentences with stopwords removed
        """
        # Mind the name and location of the stopword file
        stopwords = [line.strip() for line in open(stopwords, encoding='UTF-8').readlines()]

        data = []
        for s in sentences:
            s = s.split(" ")
            outstr = ""
            for word in s:
                if (word not in stopwords) and (len(word) > 1):  # drop stopwords and single-character words
                    outstr += word
                    outstr += ' '
            data.append(outstr)
        return data

    @classmethod
    def del_special_symbol(cls, sentence):
        """Remove garbled characters and special symbols from the sentences."""
        data = []
        for s in sentence:
            s = cls.SPECIAL_SYMBOL_RE.sub(' ', s)
            if s:
                data.append(s)
        return data

    @classmethod
    def del_english_word(cls, sentence):
        """Remove English characters from a sentence."""
        return cls.LETTER_RE.sub('', sentence)

    @staticmethod
    def jieba_cut(sentence):
        """Segment each sentence with jieba, keeping only words longer than one character."""
        data = []
        for s in sentence:
            texts = ' '.join(jieba.cut(s))
            texts = texts.split(" ")
            outstr = ""
            for s_cut in texts:
                if len(s_cut) > 1:
                    outstr = outstr + s_cut + " "
            data.append(outstr)

        return data
data = []
with open("text/g.txt", "r", encoding="utf-8") as f:
    for s in f.readlines():
        s = re.search("^<content>(.*)</content>$", s)
        if s:
            data.append(s.group(1))

chn = Preprocess()
data = chn.del_blank_lines(data)
data = chn.del_special_symbol(data)
data = chn.del_stopwords(data, "text\\stop.txt")
data = chn.jieba_cut(data)
data = chn.del_stopwords(data, "text\\stop.txt")
data = data[0:100]

texts = []
for s in data:
    t = s.split(" ")
    t.pop()  # drop the trailing empty string left by segmentation
    texts.append(t)
dictionary = corpora.Dictionary(texts)  # avoid shadowing the built-in dict

def Bottom_Top(c_pos, max_len, window):
    """Return the half-open context window [bottom, top) around c_pos, clamped to the sentence."""
    bottom = c_pos - window
    top = c_pos + window + 1
    if bottom < 0:
        bottom = 0
    if top >= max_len:
        top = max_len
    return bottom, top
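
For instance, with window=2 in a 10-word sentence, the window is clamped at both ends (the loop below skips the center position itself):

print(Bottom_Top(0, 10, 2))  # (0, 3): context positions 0..2
print(Bottom_Top(9, 10, 2))  # (7, 10): context positions 7..9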

n_matrix = len(dictionary)  # vocabulary size (4860 for the corpus used here)
window = 2
word_matrix = np.zeros(shape=[n_matrix, n_matrix])
token2id = dictionary.token2id  # avoid shadowing the built-in id
for i in range(len(texts)):
    k = len(texts[i])
    for j in range(k):
        bottom, top = Bottom_Top(j, k, window)
        c_word = texts[i][j]
        c_pos = token2id[c_word]
        for m in range(bottom, top):
            # accumulate the co-occurrence counts
            t_word = texts[i][m]
            if m != j and t_word != c_word:
                t_pos = token2id[t_word]
                word_matrix[c_pos][t_pos] += 1
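
As a quick sanity check (a sketch, not part of the original pipeline): the counting above is symmetric in the two words, so the matrix should equal its transpose, and its size should match the vocabulary:

assert word_matrix.shape == (n_matrix, n_matrix)
assert np.allclose(word_matrix, word_matrix.T)
print("nonzero co-occurrence pairs:", np.count_nonzero(word_matrix))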

# The TensorFlow imports and the myGlove class are exactly as defined above
# and are not repeated here.
init = {
    "W":"text\\w.txt",
    "C":"text\\c.txt",
    "bw":"text\\bw.txt",
    "bc":"text\\bc.txt"
}
glove = myGlove(n=10, max_iter=100, learning_rate=0.01, init=init)
# fixed_initialization must stay None on the TensorFlow backend (the original
# code forbids it there); the fixed values are injected through the
# overridden _weight_init instead.
G = glove.fit(word_matrix, fixed_initialization=None)
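
The returned matrix G has one embedding row per vocabulary id, so a word's vector can be looked up through the gensim dictionary. A minimal sketch (the output file name is hypothetical):

word = texts[0][0]              # any word from the corpus
print(word, G[token2id[word]])  # row index is the dictionary's token id

np.savetxt("text\\vectors.txt", G)  # hypothetical path; persist for later use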

