# Requires the pytorch-crf package; install it from a shell first:
#     pip install pytorch-crf
"""Minimal walkthrough of the ``pytorch-crf`` ``CRF`` layer.

Scores a batch of tag sequences with the CRF (with and without a padding mask)
and Viterbi-decodes the best tag sequences.

``CRF(num_tags)`` is built with the library default ``batch_first=False``, so
``emissions``, ``tags`` and ``mask`` are laid out time-major:
``(seq_length, batch_size, ...)``. ``decode`` instead returns a nested list
indexed ``[batch][time]``.

The values in the ``# e.g.`` comments come from one run. ``emissions`` is drawn
with ``torch.randn``, so the numbers differ on every run.
"""
import torch
from torchcrf import CRF

num_tags = 5  # number of distinct tags
model = CRF(num_tags)  # batch_first=False -> inputs are (seq_length, batch_size, ...)
seq_length = 3  # maximum sequence length in a batch
batch_size = 2  # number of samples in the batch
emissions = torch.randn(seq_length, batch_size, num_tags)
tags = torch.tensor([[0, 1], [2, 4], [3, 1]], dtype=torch.long)  # (seq_length, batch_size)

# Log-likelihood of `tags` given `emissions`, summed over the batch
# (note the SumBackward0 grad_fn in the output).
print(model(emissions, tags))  # e.g. tensor(-9.9695, grad_fn=<SumBackward0>)

# mask is (seq_length, batch_size): 1 = real token (kept), 0 = padding (masked).
# pytorch-crf requires the first time step of every sequence to be unmasked and
# assumes padding only at the END of a sequence, so each column must look like
# 1...1 0...0. A gap such as [1, 0, 1] silently produces a wrong score.
# Here sample 0 has length 3 and the last sample (sample 1) has length 1.
# A bool mask avoids PyTorch's deprecated uint8-condition path in torch.where.
mask = torch.tensor([[1, 1], [1, 0], [1, 0]], dtype=torch.bool)
print(model(emissions, tags, mask=mask))  # e.g. tensor(-10.7121, grad_fn=<SumBackward0>)

# Viterbi decoding returns batch_size lists of seq_length tag ids, i.e.
# [batch][time]. This is the transpose of the time-major input layout.
best_paths = model.decode(emissions)
print(best_paths)  # e.g. [[0, 4, 0], [4, 4, 1]]
print(torch.tensor(best_paths).shape)  # torch.Size([2, 3])
print(emissions.shape)  # torch.Size([3, 2, 5])