我只想报告我的发现,关于装载一个根西姆嵌入Pythorch。Pythorch0.4.0和更新版本的解决方案:
在v0.4.0中有一个新函数^{},这使得加载嵌入非常方便。
以下是文档中的一个示例。>> # FloatTensor containing pretrained weights
>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>> embedding = nn.Embedding.from_pretrained(weight)
>> # Get embeddings for index 1
>> input = torch.LongTensor([1])
>> embedding(input)
来自gensim的权重很容易通过以下方法获得:import gensim
model = gensim.models.KeyedVectors.load_word2vec_format('path/to/file')
weights = torch.FloatTensor(model.vectors) # formerly syn0, which is soon deprecatedPythorch版本0.3.1及更早版本的解决方案:
我使用的是0.3.1版本,而^{}在此版本中不可用。
因此,我创建了自己的from_pretrained,这样我也可以将它与0.3.1一起使用。
PyTorch版本的from_pretrained或更低版本的代码:def from_pretrained(embeddings, freeze=True):
assert embeddings.dim() == 2, \
'Embeddings parameter is expected to be 2-dimensional'
rows, cols = embeddings.shape
embedding = torch.nn.Embedding(num_embeddings=rows, embedding_dim=cols)
embedding.weight = torch.nn.Parameter(embeddings)
embedding.weight.requires_grad = not freeze
return embedding
嵌入可以像这样加载:embedding = from_pretrained(weights)
我希望这对某人有帮助。