1. Definition of a CRF
A conditional random field (CRF) is a Markov random field over the random variable $Y$ conditioned on the random variable $X$. That is, given $X$, the components of $Y$ form a Markov random field represented by an undirected graph $G=(V,E)$.
Formally, if $P(Y_v \mid X, Y_w, w \neq v) = P(Y_v \mid X, Y_w, w \sim v)$ holds for every node $v$, then the conditional distribution $P(Y \mid X)$ is called a conditional random field.
Here,
- $w \sim v$ denotes all nodes $w$ connected to node $v$ by an edge
- $w \neq v$ denotes all nodes other than $v$
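Linear-chain CRF
The general definition of a CRF does not require $X$ and $Y$ to have the same structure. For tasks such as sequence labeling and named entity recognition, however, $X$ and $Y$ are usually assumed to share the same graph structure.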
In this case, $P(Y \mid X)$ forms a conditional random field and satisfies the Markov property
$$P(Y_i \mid X, Y_1, \cdots, Y_{i-1}, Y_{i+1}, \cdots, Y_n) = P(Y_i \mid X, Y_{i-1}, Y_{i+1})$$
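In practice,
- $X$ is the observation sequence and $Y$ is the corresponding output (state) sequence
- During training, $P(Y \mid X)$ is learned by regularized maximum likelihood estimation
- During prediction, for a given input sequence $x$, the goal is to find the output sequence $y$ that maximizes the conditional probability $P(y \mid x)$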
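The training and prediction steps above can be written out as a rough sketch. The text does not say which regularizer is used, so the L2 penalty with strength $\lambda$ below is an assumption, as are the symbols $\theta$ (all model parameters) and $N$ (the number of training pairs $(x^{(k)}, y^{(k)})$):

$$
\begin{aligned}
\hat{\theta} &= \arg\max_{\theta} \; \sum_{k=1}^{N} \log P\left(y^{(k)} \mid x^{(k)}; \theta\right) - \frac{\lambda}{2} \lVert \theta \rVert_2^2 \\
y^{*} &= \arg\max_{y} \; P\left(y \mid x; \hat{\theta}\right)
\end{aligned}
$$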
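2. Computing the conditional probability
By the definition of a linear-chain CRF, given an observation sequence $x$, the conditional probability of an output sequence $y$ is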
$$p(y \mid x) = \frac{\exp[\mathrm{Score}(x,y)]}{\sum_{y'} \exp[\mathrm{Score}(x,y')]}$$
where $\mathrm{Score}(x,y)$ is defined as
$$
\begin{aligned}
\mathrm{Score}(x,y) &= \sum_{i=1}^{n} \psi_i(x,y) \\
&= \sum_{i=1}^{n} \left[\mathrm{trans}(y_{i-1},y_{i})+\mathrm{emit}(x_i,y_i)\right] \\
&= \sum_{i=1}^{n} \left[A_{y_{i-1},y_i}+P_{x_i,y_i}\right]
\end{aligned}
$$
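The potential function $\psi_i(x,y)$ has two parts: a transition function and an emission function. The transition function scores the move from the previous tag $y_{i-1}$ to the current tag $y_i$ in the output sequence, and the emission function scores the current output tag $y_i$ against the current observation $x_i$. Both are unnormalized scores rather than probabilities; normalization happens only through the denominator of $p(y \mid x)$.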
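As a small illustration of how the score decomposes, the sketch below evaluates it for one tag path. All numbers are made up, and the names `start`, `A`, `P` and `score` are hypothetical: `start` stands in for the transition out of $y_0$, and `P` is indexed by position rather than by observation, on the assumption that the emission scores for $x_i$ have already been computed.

```python
# Toy setup: T = 2 tags, sequence length n = 3. All numbers are made up.
start = [0.5, -0.5]   # assumed stand-in for trans(y_0, y_1)
A = [[1.0, -1.0],     # A[s][t]: score of moving from tag s to tag t
     [0.0, 2.0]]
P = [[0.2, 0.8],      # P[i][t]: emission score of tag t at position i
     [1.5, 0.1],
     [0.3, 0.9]]


def score(y):
    """Sum the transition and emission scores along the tag path y."""
    total = start[y[0]] + P[0][y[0]]
    for i in range(1, len(y)):
        total += A[y[i - 1]][y[i]] + P[i][y[i]]
    return total


# Path [0, 0, 1]: (0.5 + 0.2) + (1.0 + 1.5) + (-1.0 + 0.9), roughly 3.1
print(score([0, 0, 1]))
```

The score is a plain sum, so changing one tag only affects the emission term at that position and the two transitions touching it.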
During training, taking the logarithm of the conditional probability above gives
$$\log[p(y \mid x)] = \mathrm{Score}(x,y) - \log\left[\sum_{y'} \exp(\mathrm{Score}(x,y'))\right]$$
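The first term is easy to compute. The second term, called the log sum here, is the hard part: computed naively, it needs the score of every possible output sequence, and there are $T^n$ of them for $T$ tags and sequence length $n$.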
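To make that cost concrete, here is a minimal brute-force sketch that enumerates every tag path of a toy model. The parameters and the helper names (`start`, `A`, `P`, `score`, `log_partition_brute_force`, `log_prob`) are hypothetical, and `start` is an assumed stand-in for the transition out of $y_0$. It is only feasible for tiny inputs, since the loop visits all $T^n$ paths.

```python
import itertools
import math

# Toy setup: T = 2 tags, sequence length n = 3. All numbers are made up.
start = [0.5, -0.5]   # assumed stand-in for trans(y_0, y_1)
A = [[1.0, -1.0],     # A[s][t]: transition score from tag s to tag t
     [0.0, 2.0]]
P = [[0.2, 0.8],      # P[i][t]: emission score of tag t at position i
     [1.5, 0.1],
     [0.3, 0.9]]


def score(y):
    """Sum the transition and emission scores along the tag path y."""
    total = start[y[0]] + P[0][y[0]]
    for i in range(1, len(y)):
        total += A[y[i - 1]][y[i]] + P[i][y[i]]
    return total


def log_partition_brute_force(n, T):
    """log of the sum of exp(score) over all T**n tag paths."""
    scores = [score(y) for y in itertools.product(range(T), repeat=n)]
    m = max(scores)  # subtract the max before exponentiating, for stability
    return m + math.log(sum(math.exp(s - m) for s in scores))


def log_prob(y):
    """log p(y | x) = Score(x, y) - log sum."""
    return score(y) - log_partition_brute_force(len(y), len(A))


# The probabilities of all 2**3 = 8 paths should add up to 1
# (up to floating-point error).
total = sum(math.exp(log_prob(y)) for y in itertools.product(range(2), repeat=3))
print(total)
```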
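Computing the log sum
Suppose each position of $y$ can take one of $T$ states $s_1, \cdots, s_T$, and write $Y_k$ for the set of all state sequences of length $k$. Let $S_i(x,y)$ denote the score of the sequence up to position $i$, and let $f(i,s_j)$ denote the log sum of the scores of all partial sequences up to position $i$ whose last state is $y_i=s_j$.
It then follows from the definitions that: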
$$
\begin{aligned}
f(i,s_j) &= \log \sum_{y'\in Y_{i-1},\, y_i=s_j} \exp\left[S_i(x,y',s_j)\right] \\
&= \log \sum_{t=1}^T \sum_{y'\in Y_{i-2},\, y_{i-1}=s_t,\, y_i=s_j} \exp\left[S_i(x,y',s_t,s_j)\right] \\
&= \log \sum_{t=1}^T \sum_{y'\in Y_{i-2},\, y_{i-1}=s_t} \exp\left[S_{i-1}(x,y',s_t)+A_{s_t,s_j}+P_{x_i,s_j}\right] \\
&= \log \sum_{t=1}^T \left( \exp\left[A_{s_t,s_j}+P_{x_i,s_j}\right] \cdot \sum_{y'\in Y_{i-2},\, y_{i-1}=s_t} \exp\left[S_{i-1}(x,y',s_t)\right] \right) \\
&= \log \sum_{t=1}^T \left( \exp\left[A_{s_t,s_j}+P_{x_i,s_j}\right] \cdot \exp\left( \log \sum_{y'\in Y_{i-2},\, y_{i-1}=s_t} \exp\left[S_{i-1}(x,y',s_t)\right] \right) \right) \\
&= \log \sum_{t=1}^T \exp\left[A_{s_t,s_j}+P_{x_i,s_j}\right] \cdot \exp\left[f(i-1,s_t)\right] \\
&= \log \sum_{t=1}^T \exp\left[f(i-1,s_t)+A_{s_t,s_j}+P_{x_i,s_j}\right]
\end{aligned}
$$
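Therefore, the log sum at the current position can be computed from the log sum at the previous position, which brings the cost down from $T^n$ sequence scores to $O(nT^2)$ operations.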
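The recursion can be checked numerically on a toy model. The sketch below is an illustration under assumptions, not the implementation from the next section: the parameters are made up, `start` is an assumed stand-in for the transition out of $y_0$, and the function names are hypothetical. It compares the recursive log sum against brute-force enumeration.

```python
import itertools
import math

# Toy setup: T = 2 tags, sequence length n = 3. All numbers are made up.
start = [0.5, -0.5]   # assumed stand-in for trans(y_0, y_1)
A = [[1.0, -1.0],     # A[t][j]: transition score from tag t to tag j
     [0.0, 2.0]]
P = [[0.2, 0.8],      # P[i][j]: emission score of tag j at position i
     [1.5, 0.1],
     [0.3, 0.9]]


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_partition_forward(start, A, P):
    """Log sum via f(i, s_j) = logsumexp_t[f(i-1, s_t) + A[t][j] + P[i][j]]."""
    T = len(start)
    f = [start[j] + P[0][j] for j in range(T)]  # f(1, s_j)
    for i in range(1, len(P)):
        f = [logsumexp([f[t] + A[t][j] + P[i][j] for t in range(T)])
             for j in range(T)]
    return logsumexp(f)  # combine over the final state


def log_partition_brute_force(start, A, P):
    """Same quantity by enumerating all T**n paths (exponential cost)."""
    T, n = len(start), len(P)
    scores = []
    for y in itertools.product(range(T), repeat=n):
        s = start[y[0]] + P[0][y[0]]
        for i in range(1, n):
            s += A[y[i - 1]][y[i]] + P[i][y[i]]
        scores.append(s)
    return logsumexp(scores)


fast = log_partition_forward(start, A, P)
slow = log_partition_brute_force(start, A, P)
print(fast, slow)  # the two values should agree up to floating-point error
```

The recursive version does $n \cdot T^2$ additions inside the log-sum-exp, while the brute-force version visits all $T^n$ paths.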
3. Implementation
Based on the formulas above, the following CRF class is implemented in PyTorch, covering both the forward pass and decoding.
import torch
import torch.nn as nn
from typing import Optional, List


class CRF(nn.Module):
    """
    This module implements a conditional random field.

    The forward computation of this class computes the log-likelihood of the given
    sequence of tags and emission score tensor. This class also has a `decode` method,
    which finds the best tag sequence given an emission score tensor using the
    Viterbi algorithm.
    """

    def __init__(self, num_tags: int, batch_first: bool = False):
        """
        Args:
            num_tags: Number of tags.
            batch_first: Whether the first dimension corresponds to the size of a minibatch.

        Attributes:
            start_transitions: Start transition score tensor of size (num_tags,).
            end_transitions: End transition score tensor of size (num_tags,).
            transitions: Transition matrix tensor of size (num_tags, num_tags).
        """
        if num_tags <= 0:
            raise ValueError(f'num_tags must be positive, got {num_tags}.')
        super(CRF, self).__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        # Initialize the transition parameters.
        self.reset_parameters()

    def __repr__(self):
        # Short description of the module.
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def reset_parameters(self):
        """
        Initialize the transition parameters.

        The parameters are drawn from a uniform distribution between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def forward(
        self,
        emissions: torch.Tensor,
        tags: torch.LongTensor,
        mask: Optional[torch.Tensor] = None,
        reduction: str = 'sum'
    ) -> torch.Tensor:
        """
        Args:
            emissions: Emission score tensor of size (batch_sz, seq_len, num_tags)
                if batch_first else (seq_len, batch_sz, num_tags).
            tags: Tag sequence tensor of size (batch_sz, seq_len)
                if batch_first else (seq_len, batch_sz).
            mask: Mask tensor of size (batch_sz, seq_len) if batch_first else
                (seq_len, batch_sz). The first timestep of every sequence must be unmasked.
            reduction: Reduction applied to the output, one of
                'none', 'sum', 'mean', 'token_mean'.

        Returns:
            The log-likelihood, of shape (batch_sz,) if reduction='none' else ().
        """
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'Invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.bool)
        # torch.where needs a boolean condition, so normalize uint8 masks here.
        mask = mask.bool()
        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)
        # (batch_sz,)
        numerator = self._compute_score(emissions, tags, mask)
        # (batch_sz,)
        denominator = self._compute_log_sum_score(emissions, mask)
        # (batch_sz,)
        llh = numerator - denominator  # log-likelihood
        if reduction == 'none':  # no reduction
            return llh
        if reduction == 'sum':  # sum over the batch
            return llh.sum()
        if reduction == 'mean':  # average over the batch
            return llh.mean()
        # reduction == 'token_mean': average over unmasked tokens
        return llh.sum() / mask.float().sum()

    def decode(self, emissions: torch.Tensor, mask: Optional[torch.Tensor] = None) -> List[List[int]]:
        """
        Find the most likely tag sequence using the Viterbi algorithm.

        Args:
            emissions: (batch_sz, seq_len, num_tags) if batch_first
                else (seq_len, batch_sz, num_tags).
            mask: (batch_sz, seq_len) if batch_first else (seq_len, batch_sz).

        Returns:
            List of lists containing the best tag sequence for each batch element.
        """
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.bool)
        mask = mask.bool()
        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)
        return self._viterbi_decode(emissions, mask)

    def _compute_score(self, emissions, tags, mask):
        """
        Score of the given tag sequences (the numerator, in log space).

        Args:
            emissions: (seq_len, batch_sz, num_tags)
            tags: (seq_len, batch_sz)
            mask: (seq_len, batch_sz), boolean
        """
        seq_len, batch_sz = tags.shape
        batch_idx = torch.arange(batch_sz, device=tags.device)
        fmask = mask.float()
        # Start transition and first emission score
        # (batch_sz,)
        score = self.start_transitions[tags[0]] + emissions[0, batch_idx, tags[0]]
        for t in range(1, seq_len):
            # Transition score to the next tag; masked steps contribute 0
            score = score + self.transitions[tags[t - 1], tags[t]] * fmask[t]
            # Emission score for the next tag
            score = score + emissions[t, batch_idx, tags[t]] * fmask[t]
        # Index of the last unmasked position, shape (batch_sz,)
        seq_end_idx = mask.long().sum(dim=0) - 1
        # Last real tag of each sequence, shape (batch_sz,)
        last_tags = tags[seq_end_idx, batch_idx]
        score = score + self.end_transitions[last_tags]
        # (batch_sz,)
        return score

    def _compute_log_sum_score(self, emissions, mask):
        """
        Log sum over the scores of all tag sequences (the denominator, in log space).

        Args:
            emissions: (seq_len, batch_sz, num_tags)
            mask: (seq_len, batch_sz), boolean
        """
        seq_len = emissions.shape[0]
        # (batch_sz, num_tags)
        score = self.start_transitions + emissions[0]
        # score[b, j] is the log sum over all partial sequences ending in tag j.
        for t in range(1, seq_len):
            # (batch_sz, num_tags) -> (batch_sz, num_tags, 1)
            broad_score = score.unsqueeze(2)
            # (batch_sz, num_tags) -> (batch_sz, 1, num_tags)
            broad_emit = emissions[t].unsqueeze(1)
            # f(i, s_j) = log-sum-exp over t of [f(i-1, s_t) + trans(s_t, s_j) + emit(s_j)]
            # (batch_sz, num_tags, num_tags)
            next_score = broad_score + self.transitions.unsqueeze(0) + broad_emit
            # (batch_sz, num_tags, num_tags) -> (batch_sz, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)
            # Keep the previous score where the timestep is masked out.
            score = torch.where(mask[t].unsqueeze(1), next_score, score)
        # (batch_sz, num_tags)
        score = score + self.end_transitions
        # (batch_sz, num_tags) -> (batch_sz,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions, mask) -> List[List[int]]:
        """
        Args:
            emissions: (seq_len, batch_sz, num_tags)
            mask: (seq_len, batch_sz), boolean

        Returns:
            List of lists containing the best tag sequence for each batch element.
        """
        seq_len, batch_sz = mask.shape
        history = []
        # (batch_sz, num_tags)
        score = self.start_transitions + emissions[0]
        # score[b, j] is the best score of any partial sequence ending in tag j.
        for t in range(1, seq_len):
            # (batch_sz, num_tags) -> (batch_sz, num_tags, 1)
            broad_score = score.unsqueeze(2)
            # (batch_sz, num_tags) -> (batch_sz, 1, num_tags)
            broad_emit = emissions[t].unsqueeze(1)
            # (batch_sz, num_tags, num_tags)
            next_score = broad_score + self.transitions.unsqueeze(0) + broad_emit
            # (batch_sz, num_tags, num_tags) -> (batch_sz, num_tags), (batch_sz, num_tags)
            next_score, idx = torch.max(next_score, dim=1)
            score = torch.where(mask[t].unsqueeze(1), next_score, score)
            # history[t - 1][b, j]: best previous tag when the tag at step t is j
            history.append(idx)
        # (batch_sz, num_tags)
        score = score + self.end_transitions
        # Index of the last unmasked position, shape (batch_sz,)
        seq_end_idx = mask.long().sum(dim=0) - 1
        best_tag_list = []
        # Trace back the best path for each batch element.
        for i in range(batch_sz):
            _, best_last_idx = torch.max(score[i], dim=0)
            best_tags = [best_last_idx.item()]
            # Only the back-pointers up to the last real timestep are valid.
            for idx in reversed(history[:int(seq_end_idx[i])]):
                best_last_idx = idx[i][best_tags[-1]]
                best_tags.append(best_last_idx.item())
            # The path was collected from the end, so restore forward order.
            best_tags.reverse()
            best_tag_list.append(best_tags)
        return best_tag_list
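The decoding step uses the same recursion as the log sum, with the log-sum-exp replaced by a max and a table of back-pointers kept for the traceback. The pure-Python sketch below illustrates that idea on a toy model. It is a simplified stand-in, not the class above: it has no batching, masking or end transitions, the numbers are made up, and the names `start`, `A`, `P`, `viterbi` and `best_path_brute_force` are hypothetical.

```python
import itertools

# Toy setup: T = 2 tags, sequence length n = 3. All numbers are made up.
start = [0.5, -0.5]   # assumed start scores
A = [[1.0, -1.0],     # A[t][j]: transition score from tag t to tag j
     [0.0, 2.0]]
P = [[0.2, 0.8],      # P[i][j]: emission score of tag j at position i
     [1.5, 0.1],
     [0.3, 0.9]]


def viterbi(start, A, P):
    """Return (best_path, best_score) using max in place of log-sum-exp."""
    T = len(start)
    best = [start[j] + P[0][j] for j in range(T)]
    back = []  # back[i - 1][j]: best previous tag if position i has tag j
    for i in range(1, len(P)):
        prev = best
        ptr = [max(range(T), key=lambda t: prev[t] + A[t][j]) for j in range(T)]
        best = [prev[ptr[j]] + A[ptr[j]][j] + P[i][j] for j in range(T)]
        back.append(ptr)
    # Trace back from the best final tag, then restore forward order.
    path = [max(range(T), key=lambda j: best[j])]
    for ptr in reversed(back):
        path.append(ptr[path[-1]])
    path.reverse()
    return path, max(best)


def best_path_brute_force(start, A, P):
    """Reference answer: score all T**n paths and keep the best one."""
    def score(y):
        s = start[y[0]] + P[0][y[0]]
        for i in range(1, len(y)):
            s += A[y[i - 1]][y[i]] + P[i][y[i]]
        return s
    best = max(itertools.product(range(len(start)), repeat=len(P)), key=score)
    return list(best), score(best)


path, path_score = viterbi(start, A, P)
print(path)  # [1, 1, 1]
```

Note that the traceback collects tags from the last position backwards, so the path has to be reversed before it is returned.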