Implementing fixed initialization requires preparing the W, C, bw and bc parameters. Because the original code forbids the fixed_initialization method when running under TensorFlow, we need to write a simple wrapper class:
import numpy as np
from mittens import GloVe
from mittens.mittens_base import GloVeBase
try:
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()
except ImportError:
    import tensorflow as tf

class myGlove(GloVe):
    def __init__(self, n=100, xmax=100, alpha=0.75, max_iter=100,
                 learning_rate=0.05, tol=1e-4, display_progress=10,
                 log_dir=None, log_subdir=None, test_mode=False, init=None):
        self.init = init
        # Skip GloVeBase.__init__ and call the shared base directly,
        # mirroring how GloVeBase itself initializes (mittens=0 means pure GloVe).
        super(GloVeBase, self).__init__(n=n,
                                        xmax=xmax,
                                        alpha=alpha,
                                        mittens=0,
                                        max_iter=max_iter,
                                        learning_rate=learning_rate,
                                        tol=tol,
                                        display_progress=display_progress,
                                        log_dir=log_dir,
                                        log_subdir=log_subdir,
                                        test_mode=test_mode)

    def _weight_init(self, m, n, name):
        if self.init:
            # Load the fixed initial values for W, C, bw or bc from disk.
            print(self.init[name])
            data = tf.cast(np.loadtxt(self.init[name]).reshape(m, n),
                           dtype=tf.float32)
            with tf.name_scope(name) as scope:
                return tf.Variable(data, name=name)
        else:
            # Default Xavier-style uniform initialization from mittens.
            x = np.sqrt(6.0 / (m + n))
            with tf.name_scope(name) as scope:
                return tf.Variable(
                    tf.random_uniform(
                        [m, n], minval=-x, maxval=x), name=name)
The init parameter is a dictionary of the following form:
init = {
    "W": "text\\w.txt",
    "C": "text\\c.txt",
    "bw": "text\\bw.txt",
    "bc": "text\\bc.txt"
}
The shapes of W, C, bw and bc are as follows:
W = (n_words, n)
C = (n_words, n)
bw = (n_words, 1)
bc = (n_words, 1)
n_words is the vocabulary size, and n is the n configured on GloVe.
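If pretrained values are not available yet, the four files can be generated with numpy. Below is a minimal sketch; the file names match the init dict above, the uniform range mirrors the default _weight_init, and the concrete n_words and n values are only example assumptions:

import numpy as np

n_words, n = 4860, 10  # example values: vocabulary size and embedding dimension

def save_uniform(path, m, k):
    # Same Xavier-style uniform range as the default _weight_init above.
    x = np.sqrt(6.0 / (m + k))
    np.savetxt(path, np.random.uniform(-x, x, size=(m, k)))

save_uniform("text\\w.txt", n_words, n)
save_uniform("text\\c.txt", n_words, n)
save_uniform("text\\bw.txt", n_words, 1)
save_uniform("text\\bc.txt", n_words, 1)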
A complete usage example:
# coding:utf8
"""Common utilities for Chinese NLP preprocessing."""
import re
import string

import numpy as np
import jieba
from gensim import corpora


class Preprocess(object):
    """Chinese NLP preprocessing class."""
    # Regular expressions used to clean the data
    DIGIT_RE = re.compile(r'\d+')
    LETTER_RE = re.compile(r'[a-zA-Z]+')
    SPECIAL_SYMBOL_RE = re.compile(r'[^\u4e00-\u9fa5]+')  # removes everything outside the CJK range
    STOPS = ['。', '.', '?', '?', '!', '!']  # Chinese and English sentence-final characters
    # Minimum and maximum allowed sentence lengths
    SENTENCE_MIN_LEN = 5
    SENTENCE_MAX_LEN = 50
    def __init__(self, stopwords=None):
        self.stopwords = stopwords

    @staticmethod
    def read_text_file(text_file):
        """Read a text file and return a list with one element per line."""
        with open(text_file, 'r', encoding='utf-8') as file:
            lines = [line.strip() for line in file]
        return lines

    @staticmethod
    def write_text_file(text_list, target_file):
        """Write a list of texts to the target file.
        Args:
            text_list: list, one text per element
            target_file: string, path of the target file
        """
        with open(target_file, 'w', encoding='utf-8') as writer:
            for text in text_list:
                writer.write(text + '\n')

    @staticmethod
    def del_blank_lines(sentences):
        """Remove blank lines from a list of sentences.
        Args:
            sentences: list of strings
        """
        data = []
        for s in sentences:
            s = s.strip()
            if s:
                data.append(s)
        return data
    @staticmethod
    def del_punctuation(sentence):
        """Remove Chinese and English punctuation from a string.
        Args:
            sentence: string
        """
        en_punc_tab = str.maketrans('', '', string.punctuation)  # string.punctuation cannot handle symbols like ↓ ① ℃
        sent_no_en_punc = sentence.translate(en_punc_tab)
        return re.sub(r'[%s]+' % string.punctuation, " ", sent_no_en_punc)

    @staticmethod
    def del_stopwords(sentences, stopwords):
        """Remove stopwords from the sentences.
        Args:
            sentences: list of whitespace-tokenized sentences
            stopwords: path of the stopword file
        Returns: list of sentences with the stopwords removed
        """
        # Mind the name and location of the stopword file
        stopwords = [line.strip() for line in open(stopwords, encoding='UTF-8').readlines()]
        data = []
        for s in sentences:
            s = s.split(" ")
            outstr = ""
            for word in s:
                if (word not in stopwords) and (len(word) > 1):  # skip stopwords and single-character tokens
                    outstr += word
                    outstr += ' '
            data.append(outstr)
        return data
    @classmethod
    def del_special_symbol(cls, sentence):
        """Remove garbled characters and special symbols from the sentences."""
        data = []
        for s in sentence:
            s = cls.SPECIAL_SYMBOL_RE.sub(' ', s)
            if s:
                data.append(s)
        return data

    @classmethod
    def del_english_word(cls, sentence):
        """Remove English characters from a sentence."""
        return cls.LETTER_RE.sub('', sentence)

    @staticmethod
    def jieba_cut(sentence):
        """Tokenize the sentences with jieba, keeping tokens longer than one character."""
        data = []
        for s in sentence:
            texts = ' '.join(jieba.cut(s))
            texts = texts.split(" ")
            outstr = ""
            for s_cut in texts:
                if len(s_cut) > 1:
                    outstr = outstr + s_cut + " "
            data.append(outstr)
        return data
data = []
with open("text/g.txt", "r", encoding="utf-8") as f:
    for s in f.readlines():
        # keep only the payload between <content> tags
        s = re.search("^<content>(.*)</content>$", s)
        if s:
            data.append(s.group(1))

chn = Preprocess()
data = chn.del_blank_lines(data)
data = chn.del_special_symbol(data)
data = chn.del_stopwords(data, "text\\stop.txt")
data = chn.jieba_cut(data)
data = chn.del_stopwords(data, "text\\stop.txt")
data = data[0:100]  # keep only the first 100 sentences for this demo
print(data)
texts = []
for s in data:
    t = s.split(" ")
    t.pop()  # drop the trailing empty string left by the final space
    texts.append(t)
dictionary = corpora.Dictionary(texts)
def Bottom_Top(c_pos, max_len, window):
    """Clamp the context range [c_pos-window, c_pos+window] to the sentence bounds."""
    bottom = c_pos - window
    top = c_pos + window + 1
    if bottom < 0:
        bottom = 0
    if top >= max_len:
        top = max_len
    return bottom, top


n_matrix = len(dictionary)  # vocabulary size (4860 for this corpus)
window = 2
word_matrix = np.zeros(shape=[n_matrix, n_matrix])
token2id = dictionary.token2id
for i in range(len(texts)):
    k = len(texts[i])
    for j in range(k):
        bottom, top = Bottom_Top(j, k, window)
        c_word = texts[i][j]
        c_pos = token2id[c_word]
        for m in range(bottom, top):
            # accumulate co-occurrence counts within the window
            t_word = texts[i][m]
            if m != j and t_word != c_word:
                t_pos = token2id[t_word]
                word_matrix[c_pos][t_pos] += 1
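# Optional sanity check (not in the original script): with a symmetric
# window and a symmetric skip condition, the co-occurrence matrix should
# come out symmetric.
assert np.allclose(word_matrix, word_matrix.T)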
# The myGlove wrapper class and the init dict are defined exactly as above.
glove = myGlove(n=10, max_iter=100, learning_rate=0.01, init=init)
# fixed_initialization must remain None in the TensorFlow build; the fixed
# values are instead loaded by the overridden _weight_init via the init dict.
G = glove.fit(word_matrix, fixed_initialization=None)
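fit returns the learned embedding matrix with one row per vocabulary id, so individual word vectors can be looked up through the gensim dictionary built earlier. A minimal sketch (the word picked here is arbitrary):

word = texts[0][0]            # any word that occurs in the corpus
vector = G[token2id[word]]    # the learned n-dimensional vector for this word
print(word, vector)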