Named Entity Recognition with an LSTM

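The official PyTorch documentation has a tutorial on sequence tagging with an LSTM (its example is part-of-speech tagging), and the same approach works for named entity recognition: https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html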

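Of course, the official tutorial leaves out some pieces that a beginner needs when using an LSTM for named entity recognition. I have filled in that code here, but I will not go into the underlying theory.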

First, build the LSTM model.

import torch.nn as nn
import torch.nn.functional as F


class LSTM_Model(nn.Module):
    def __init__(self, vocabSize, embedDim, hiddenDim, tagSize):
        super(LSTM_Model, self).__init__()
        self.embeds = nn.Embedding(vocabSize, embedDim)
        self.lstm = nn.LSTM(embedDim, hiddenDim)
        self.hidden2tag = nn.Linear(hiddenDim, tagSize)

    def forward(self, sentSeq):
        embeds = self.embeds(sentSeq)
        output, hidden = self.lstm(embeds.view(len(sentSeq), 1, -1))
        tagSpace = self.hidden2tag(output.view(len(sentSeq), -1))
        result = F.log_softmax(tagSpace, dim=1)
        return result
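
To make the reshaping inside forward concrete, here is a small shape check. It is only a sketch: the sizes (a vocabulary of 10, embedding and hidden size 6, 3 tags) and the token ids are made-up values, and the class is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTM_Model(nn.Module):
    # Same model as above, repeated so this snippet is self-contained.
    def __init__(self, vocabSize, embedDim, hiddenDim, tagSize):
        super().__init__()
        self.embeds = nn.Embedding(vocabSize, embedDim)
        self.lstm = nn.LSTM(embedDim, hiddenDim)
        self.hidden2tag = nn.Linear(hiddenDim, tagSize)

    def forward(self, sentSeq):
        embeds = self.embeds(sentSeq)  # (seq_len, embedDim)
        # nn.LSTM expects (seq_len, batch, features); the batch size here is 1
        output, _ = self.lstm(embeds.view(len(sentSeq), 1, -1))
        tagSpace = self.hidden2tag(output.view(len(sentSeq), -1))  # (seq_len, tagSize)
        return F.log_softmax(tagSpace, dim=1)


torch.manual_seed(0)
demoModel = LSTM_Model(vocabSize=10, embedDim=6, hiddenDim=6, tagSize=3)
demoSeq = torch.tensor([1, 4, 2, 7])  # a made-up sentence of 4 token ids
with torch.no_grad():
    demoScores = demoModel(demoSeq)
print(demoScores.shape)             # one row per token, one column per tag
print(demoScores.exp().sum(dim=1))  # exponentiated log-probabilities sum to 1 per row
```

The output has one row of log-probabilities per input token, which is the shape nn.NLLLoss expects when paired with a vector of target tag ids.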

With the model built, train it. The training code follows.

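import torch
import torch.nn as nn
import torch.optim as optim

# word2id, tag2id, wordLists and tagLists are assumed to be built from the
# training data already; EMBEDDING_DIM and HIDDEN_DIM are hyperparameters
# that you must set yourself.
model = LSTM_Model(len(word2id), EMBEDDING_DIM, HIDDEN_DIM, len(tag2id))
lossFunction = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-1)

for epoch in range(300):
    for wordList, tagList in zip(wordLists, tagLists):
        model.zero_grad()  # clear gradients accumulated from the previous step
        inputs = torch.tensor([word2id[word] for word in wordList])
        tagSeq = torch.tensor([tag2id[tag] for tag in tagList])
        tagScore = model(inputs)
        loss = lossFunction(tagScore, tagSeq)
        loss.backward()
        optimizer.step()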

Now look at the result after training.

with torch.no_grad():
    testText = ['欧', '美', '港', '台']
    testSeq = torch.tensor([word2id[word] for word in testText]).long()
    tags_scores = model(testSeq)
    print(tags_scores)
    # Pick the highest-scoring tag id for each character
    _, predictId = torch.max(tags_scores, dim=1)
    # Invert tag2id so predicted ids can be mapped back to tag names
    id2tag = {idx: tag for tag, idx in tag2id.items()}
    tagList = [id2tag[idx] for idx in predictId.tolist()]
    printZip(testText, tagList)
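
Note that the lookup word2id[word] raises a KeyError for any character that never appeared in the training data. One common workaround, which is not part of the original code, is to reserve an id for unknown characters. In the sketch below, UNK, buildVocab and encode are hypothetical names introduced only for illustration.

```python
UNK = '<UNK>'  # hypothetical placeholder token, not part of the original code


def buildVocab(wordLists):
    """Map every character in the training sentences to an id; id 0 is UNK."""
    vocab = {UNK: 0}
    for wordList in wordLists:
        for word in wordList:
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab


def encode(wordList, vocab):
    """Turn characters into ids, falling back to the UNK id for unseen ones."""
    return [vocab.get(word, vocab[UNK]) for word in wordList]


demoVocab = buildVocab([['欧', '美'], ['港', '台']])
print(encode(['欧', '洲'], demoVocab))  # '洲' was never seen, so it maps to the UNK id
```

With this scheme the UNK embedding only learns something useful if some training characters are also mapped to UNK (for example, very rare ones); otherwise it stays at its random initial value.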

Below is the code I wrote for reading the training data.

filePath.py

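import os
import sys

# The data directory sits next to the script's directory, one level up.
# os.path handles both Windows and POSIX path separators.
basePath = os.path.join(os.path.dirname(sys.path[0]), 'data')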

loadText.py

class loadData():
    def __init__(self):
        pass

    def loadLists(self, filename):
        """Read a tab-separated file into per-sentence word and tag lists."""
        print('loading data...')
        wordLists = []
        tagLists = []
        wordList = []
        tagList = []
        with open(filename, encoding='utf-8') as f:
            for textLine in f:
                if textLine != '\n':
                    word, tag = textLine.strip().split('\t')
                    wordList.append(word)
                    tagList.append(tag)
                else:
                    # A blank line marks the end of a sentence
                    wordLists.append(wordList)
                    tagLists.append(tagList)
                    wordList = []
                    tagList = []
        # Keep the last sentence when the file does not end with a blank line
        if wordList:
            wordLists.append(wordList)
            tagLists.append(tagList)
        print('loading done.')
        return wordLists, tagLists

    def getVocab(self, sentence):
        vocab = {}
        for word in sentence:
            if word not in vocab:
                vocab[word] = len(vocab)
        return vocab

    def text2sentences(self, filename):
        """Split each line of a plain-text file into characters, skipping spaces."""
        with open(filename, encoding='utf-8') as f:
            textList = f.read().split('\n')
        sentences = []
        for text in textList:
            sentences.append([word for word in text if word != ' '])
        return sentences

    def loadList(self, filename):
        """
        :return: wordList,tagList
        """
        wordList = []
        tagList = []
        with open(filename, encoding='utf-8') as f:
            textLines = f.readlines()
        for textLine in textLines:
            if textLine != '\n':
                text_list = textLine.strip().split('\t')
                wordList.append(text_list[0])
                tagList.append(text_list[1])
        return wordList, tagList

pprint.py

def printZip(list1, list2):
    pairs = []
    for node1, node2 in zip(list1, list2):
        pairs.append(node1)
        pairs.append(node2)
    print(pairs)

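That is all the code. As loadLists shows, the training data has one character and its tag per line, separated by a tab, with a blank line between sentences.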
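
Since the original data file is not reproduced here, the sketch below writes a tiny made-up sample in the tab-separated layout that loadLists parses, reads it back, and builds the two id mappings used in training. The sentences and the tag names (B-LOC, O) are invented for the example, and parseTagged and toIds are hypothetical stand-ins rather than the code from loadText.py.

```python
import os
import tempfile

# Made-up sample: one character and its tag per line, separated by a tab;
# a blank line ends each sentence. The tag names are assumptions.
SAMPLE = '欧\tB-LOC\n美\tB-LOC\n\n你\tO\n好\tO\n\n'


def parseTagged(filename):
    """Hypothetical stand-in for loadData.loadLists."""
    wordLists, tagLists, wordList, tagList = [], [], [], []
    with open(filename, encoding='utf-8') as f:
        for line in f:
            if line.strip():
                word, tag = line.strip().split('\t')
                wordList.append(word)
                tagList.append(tag)
            elif wordList:
                wordLists.append(wordList)
                tagLists.append(tagList)
                wordList, tagList = [], []
    if wordList:  # keep the last sentence if there is no trailing blank line
        wordLists.append(wordList)
        tagLists.append(tagList)
    return wordLists, tagLists


def toIds(sequences):
    """Assign an id to each distinct item in order of first appearance."""
    vocab = {}
    for sequence in sequences:
        for item in sequence:
            vocab.setdefault(item, len(vocab))
    return vocab


with tempfile.TemporaryDirectory() as tmpDir:
    samplePath = os.path.join(tmpDir, 'sample.txt')
    with open(samplePath, 'w', encoding='utf-8') as f:
        f.write(SAMPLE)
    wordLists, tagLists = parseTagged(samplePath)

word2id = toIds(wordLists)
tag2id = toIds(tagLists)
print(wordLists)
print(tag2id)
```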

Put these files into a PyCharm project and you can run named entity recognition with the LSTM model.
