# 需要导入模块: import torch [as 别名]
# 或者: from torch import renorm [as 别名]
# Fallback dimensions for files that carry no "<vocab_size> <vector_size>"
# header line; these are the values this loader previously hard-coded.
_DEFAULT_VOCAB_SIZE = 1917494
_DEFAULT_VECTOR_SIZE = 300


def _parse_w2v_header(line):
    """Return ``(vocab_size, vector_size)`` if *line* is a word2vec header, else ``None``.

    A header is a line made of exactly two ASCII-digit fields. Anything else
    (a vector line, binary payload, an empty line) is reported as "no header".
    """
    fields = line.split()
    if len(fields) != 2 or not all(field.isdigit() for field in fields):
        return None
    vocab_size, vector_size = int(fields[0]), int(fields[1])
    if vector_size < 1:
        return None
    return vocab_size, vector_size


def _read_binary_word(fin):
    """Read the next space-terminated word (as bytes) from a binary word2vec stream.

    Newline bytes are dropped, because some binary files put one in front of
    each word. Returns ``None`` when the stream ends cleanly before any byte of
    a new word was seen.

    :raises EOFError: if the stream ends in the middle of a word.
    """
    chars = []
    while True:
        ch = fin.read(1)
        if ch == b' ':
            return b''.join(chars)
        if ch == b'':
            # read() returning b'' means EOF; without this check the loop
            # would append empty bytes forever.
            if chars:
                raise EOFError("unexpected end of input inside a word; "
                               "is count incorrect or file otherwise damaged?")
            return None
        if ch != b'\n':
            chars.append(ch)


def load_word2vec_format(filename, word_idx, binary=False, normalize=False,
                         encoding='utf8', unicode_errors='ignore'):
    """
    Load pre-trained word embeddings in word2vec format (adapted from gensim).

    Only words present in `word_idx` are kept; all other vectors in the file
    are skipped. If you trained the C model using non-utf8 encoding for words,
    specify that encoding in `encoding`.

    The optional first line ``"<vocab_size> <vector_size>"`` is detected
    automatically. Without it, text files take their dimensionality from the
    first vector line, and binary files fall back to 1917494 vectors of size
    300 (stopping early if the file ends cleanly on a record boundary).

    :param filename : path of the embedding file.
    :param word_idx : mapping from word to its row index in the returned matrix;
                      must support ``len()``, ``in`` and item lookup.
    :param binary : a boolean indicating whether the data is in binary word2vec format.
    :param normalize: if True, apply ``torch.renorm(matrix, 2, 0, 1)``: rows whose
                      L2 norm exceeds 1 are scaled down to norm 1, rows with a
                      smaller norm (including all-zero rows) are left unchanged.
    :param encoding : encoding used to decode the words.
    :param unicode_errors: errors can be 'strict', 'replace' or 'ignore' and defaults to 'ignore'.
    :return: ``(word_matrix, vector_size, vocab)`` where `word_matrix` is a
             ``len(word_idx) x vector_size`` tensor (zero rows for words not
             found in the file) and `vocab` is the set of words that were found.
    :raises ValueError: text format, a line has the wrong number of fields or a
                        kept word has a non-numeric component.
    :raises EOFError: binary format, the file ends inside a record or before
                      the number of records announced by the header.
    """
    vocab = set()
    print("loading word embedding from %s" % filename)
    with open(filename, 'rb') as fin:
        first_line = fin.readline()
        header = _parse_w2v_header(first_line)
        if header is None:
            # The first line is already data: rewind so it is not lost.
            fin.seek(0)

        if binary:
            if header is None:
                vocab_size, vector_size = _DEFAULT_VOCAB_SIZE, _DEFAULT_VECTOR_SIZE
            else:
                vocab_size, vector_size = header
            word_matrix = torch.zeros(len(word_idx), vector_size)
            binary_len = np.dtype(np.float32).itemsize * vector_size
            for _ in range(vocab_size):
                # mixed text and binary: read text first, then binary
                raw_word = _read_binary_word(fin)
                if raw_word is None:
                    if header is not None:
                        # The header promised more records than the file holds.
                        raise EOFError("unexpected end of input; "
                                       "is count incorrect or file otherwise damaged?")
                    # No header: the record count was only a guess, so a clean
                    # end of file simply ends the load.
                    break
                buf = fin.read(binary_len)
                if len(buf) != binary_len:
                    raise EOFError("unexpected end of input inside a vector; "
                                   "is count incorrect or file otherwise damaged?")
                word = raw_word.decode(encoding, unicode_errors)
                if word not in word_idx:
                    continue
                vocab.add(word)
                # frombuffer yields a read-only view of `buf`; copy it so torch
                # gets a writable array.
                weights = np.frombuffer(buf, dtype=np.float32).copy()
                word_matrix[word_idx[word]] = torch.from_numpy(weights)
        else:
            first_line_no = 0
            if header is not None:
                vector_size = header[1]
                first_line_no = 1  # keep line numbers relative to the file start
            elif first_line:
                first_parts = first_line.rstrip().decode(encoding, unicode_errors).split(" ")
                vector_size = len(first_parts) - 1
                if vector_size < 1:
                    raise ValueError("invalid vector on line %s (is this really the text format?)" % 0)
            else:
                # Empty file: nothing to infer the dimensionality from.
                vector_size = _DEFAULT_VECTOR_SIZE
            word_matrix = torch.zeros(len(word_idx), vector_size)
            for line_no, line in enumerate(fin, first_line_no):
                parts = line.rstrip().decode(encoding, unicode_errors).split(" ")
                if len(parts) != vector_size + 1:
                    raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no)
                word = parts[0]
                if word not in word_idx:
                    # Skip the float parsing for vectors that are discarded anyway.
                    continue
                vocab.add(word)
                word_matrix[word_idx[word]] = torch.tensor([float(value) for value in parts[1:]])

    if normalize:
        # Cap the L2 norm of every row at 1 (rows already below 1 are untouched).
        word_matrix = torch.renorm(word_matrix, 2, 0, 1)
    print("loaded %d words pre-trained from %s with %d" % (len(vocab), filename, vector_size))
    return word_matrix, vector_size, vocab