项目地址
先直接上代码,后面有空再完善。
class BiLSTMCRF(nn.Module):
    """BiLSTM-CRF sequence-labeling model (e.g. for NER).

    Pipeline: token ids -> Embedding -> (Bi)LSTM -> Linear emissions -> CRF.
    The CRF layer is assumed to be `torchcrf.CRF` (returns log-likelihood,
    provides `.decode` for Viterbi decoding) — confirm against the import.
    """

    def __init__(self, num_embeddings, embedding_dim, hidden_size, bidirectional, class_num):
        """
        num_embeddings: size of the dictionary of embeddings
        embedding_dim: the size of each embedding vector
        hidden_size: number of features in the LSTM hidden state `h`
        bidirectional: if True, use a bidirectional LSTM
        class_num: number of tag classes
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size,
                            batch_first=True, bidirectional=bidirectional)
        # A bidirectional LSTM concatenates forward and backward hidden
        # states, doubling the feature dimension fed to the classifier.
        out_features = hidden_size * 2 if bidirectional else hidden_size
        self.classifier = nn.Linear(out_features, class_num)
        self.crf = CRF(class_num, batch_first=True)

    def forward(self, data_index, data_len):
        """Return per-token emission scores, shape (batch, max_len, class_num).

        data_index: LongTensor of token ids, shape (batch, max_len)
        data_len: lengths of each sequence (must live on CPU)
        """
        em = self.embedding(data_index)
        # Pack so the LSTM skips padding positions. Bug fix: the default
        # enforce_sorted=True raises unless batches are pre-sorted by
        # descending length; nothing in this class guarantees that order.
        pack = nn.utils.rnn.pack_padded_sequence(
            em, data_len, batch_first=True, enforce_sorted=False)
        output, _ = self.lstm(pack)
        output, _ = nn.utils.rnn.pad_packed_sequence(
            output, batch_first=True)
        return self.classifier(output)

    def loss(self, emissions, tags, mask):
        """Negative CRF log-likelihood (the CRF returns log-likelihood, so negate)."""
        return -self.crf(emissions, tags, mask)

    def decode(self, emissions, mask=None):
        """Viterbi-decode emissions into a list of tag-id sequences (one per sample)."""
        return self.crf.decode(emissions, mask)

    def fit(self, train_dataloader, epoch, dev_dataloader):
        """Train for `epoch` epochs, evaluating on the dev set after each one.

        Each dataloader yields (data, tag, lengths, mask) batches.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        for e in range(epoch):
            print("Epoch", f"{e+1}/{epoch}")
            self.train()
            for data, tag, da_len, mask in train_dataloader:
                pred = self.forward(data, da_len)
                loss = self.loss(pred, tag, mask)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print(f'loss: {round(loss.item(), 2)}', end='\r')
            self.eval()
            # Bug fix: aggregate predictions over the WHOLE dev set instead
            # of reporting the f1 of only the last batch. Also run under
            # no_grad: evaluation needs no autograd graph.
            y_true, y_pred = [], []
            with torch.no_grad():
                for data, tag, da_len, mask in dev_dataloader:
                    # Strip padding from gold tags so they align with the
                    # variable-length decoded sequences.
                    gold = nn.utils.rnn.unpad_sequence(
                        tag, da_len, batch_first=True)
                    y_true.extend(chain.from_iterable(
                        x.cpu().numpy().tolist() for x in gold))
                    pred = self.decode(self.forward(data, da_len), mask)
                    y_pred.extend(chain.from_iterable(pred))
            f1 = f1_score(y_true, y_pred, average="micro")
            print(f"loss: {round(loss.item(),2)}\tf1: {round(f1,3)}")

    def predict(self, word_2_index, index_2_tag, filepath):
        """Load weights from `filepath` and tag one sentence read from stdin.

        word_2_index: char -> id mapping (must contain "<UNK>")
        index_2_tag: tag id -> tag name mapping
        filepath: path to a saved state_dict
        """
        self.load_state_dict(torch.load(filepath))
        # Bug fix: switch to inference mode after loading (disables dropout
        # etc.) and skip autograd bookkeeping during decoding.
        self.eval()
        text = input("请输入:")
        # Map each character to its id, falling back to the <UNK> token.
        text_index = [
            [word_2_index.get(i, word_2_index["<UNK>"]) for i in text]]
        # NOTE(review): `device` is a module-level global defined elsewhere.
        text_index = torch.tensor(text_index, device=device)
        text_len = [len(text)]
        with torch.no_grad():
            pred = self.decode(self.forward(text_index, text_len))
        pred = [index_2_tag[i] for i in pred[0]]
        print([f'{w}_{s}' for w, s in zip(text, pred)])