torchcrf

pip install pytorch-crf
import torch
from torchcrf import CRF
num_tags = 5  # number of tags is 5
model = CRF(num_tags)
seq_length = 3  # maximum sequence length in a batch
batch_size = 2  # number of samples in the batch
emissions = torch.randn(seq_length, batch_size, num_tags)
tags = torch.tensor([[0, 1], [2, 4], [3, 1]], dtype=torch.long)  # (seq_length, batch_size)
model(emissions, tags)

tensor(-9.9695, grad_fn=)

 # mask size is (seq_length, batch_size)
 # the last sample has length of 1
mask = torch.tensor([[1, 1], [0, 1], [1, 1]], dtype=torch.uint8)  #1代码不mask,0代表mask  句子长度3 batch2
model(emissions, tags, mask=mask)   
``
tensor(-10.7121, grad_fn=<SumBackward0>)
```python
model.decode(emissions) 

[[0, 4, 0], [4, 4, 1]]

a=model.decode(emissions)
a=torch.tensor(a)
a.shape

torch.Size([2, 3])

emissions.shape

torch.Size([3, 2, 5])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值