Machine Learning Algorithms (6) - Conditional Random Fields

1. Definition of a CRF

A conditional random field (CRF) is a Markov random field over the random variable $Y$ conditioned on the random variable $X$; that is, $Y$ forms a Markov random field represented by an undirected graph $G=(V,E)$.

If $P(Y_v \mid X, Y_w, w \neq v) = P(Y_v \mid X, Y_w, w \sim v)$ holds for every node $v$, then the conditional distribution $P(Y|X)$ is called a conditional random field.

where:

  • $w \sim v$ denotes all nodes $w$ connected to node $v$ by an edge
  • $w \neq v$ denotes all nodes other than node $v$

Linear-Chain CRF

The general definition of a CRF does not require $X$ and $Y$ to share the same structure, but for tasks such as sequence labeling and named entity recognition, $X$ and $Y$ are usually assumed to have the same graph structure.

In this case, $P(Y|X)$ forms a linear-chain conditional random field satisfying the Markov property
$$P(Y_i \mid X, Y_1, \cdots, Y_{i-1}, Y_{i+1}, \cdots, Y_n) = P(Y_i \mid X, Y_{i-1}, Y_{i+1})$$

In practical tasks,

  • $X$ is the observation sequence and $Y$ is the corresponding output (state) sequence
  • during training, $P(Y|X)$ is learned via regularized maximum likelihood estimation
  • during prediction, given an input sequence $x$, we find the output sequence $y$ that maximizes the conditional probability $P(y|x)$

2. Computing the Conditional Probability

By the definition of the linear-chain conditional random field, the conditional probability of an output sequence $y$ given an observation sequence $x$ is
$$p(y|x) = \frac{\exp\big[Score(x,y)\big]}{\sum_{y'} \exp\big[Score(x,y')\big]}$$

where $Score(x,y)$ is defined as
$$\begin{aligned}
Score(x,y) &= \sum_{i=1}^{n} \psi_i(x,y) \\
&= \sum_{i=1}^{n} \big[ trans(y_{i-1},y_i) + emit(x_i,y_i) \big] \\
&= \sum_{i=1}^{n} \big[ A_{y_{i-1},y_i} + P_{x_i,y_i} \big]
\end{aligned}$$

The potential function $\psi(x,y)$ consists of two parts, a transition function and an emission function. The transition function scores the move from the previous time step's output value to the current one within the output sequence, while the emission function scores how well the observation at the current time step matches the current output value.
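
To make this concrete, here is a minimal sketch (all values made up for illustration) that computes $Score(x,y)$ from a toy transition matrix A and per-position emission scores P, and recovers $p(y|x)$ by brute-force enumeration of all $T^n$ output sequences. Here P[i] holds the emission scores at position $i$ for the given observation sequence, the same layout the implementation below uses.

import itertools
import torch

num_tags, seq_len = 3, 4                # T = 3 states, n = 4 positions (made up)
A = torch.randn(num_tags, num_tags)     # A[s, t]: transition score from tag s to tag t
P = torch.randn(seq_len, num_tags)      # P[i, t]: emission score of tag t at position i

def score(y):
    # Score(x, y) = sum_i [A_{y_{i-1}, y_i} + P_{x_i, y_i}];
    # the first position has no transition term.
    s = P[0, y[0]]
    for i in range(1, seq_len):
        s = s + A[y[i - 1], y[i]] + P[i, y[i]]
    return s

# Brute-force normalizer: enumerate all T^n candidate output sequences.
all_scores = torch.stack([score(y) for y in itertools.product(range(num_tags), repeat=seq_len)])
y = (0, 1, 2, 1)                        # an arbitrary candidate output sequence
p = torch.exp(score(y)) / torch.exp(all_scores).sum()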

During training, we take the logarithm of the conditional probability above, which gives
$$\log p(y|x) = Score(x,y) - \log\Big[\sum_{y'} \exp\big(Score(x,y')\big)\Big]$$

The first term above is straightforward to compute; the crux is the second term, referred to here as the log sum, which involves the scores of all possible output sequences.

Computing the Log-Sum

Suppose each element of the sequence $y$ can take one of $T$ states. Let $S_i(x,y)$ denote the score of the sequence up to time step $i$, and let $f(i,s_j)$ denote the total log-sum score over all sequences up to that time step whose last state is $y_i = s_j$.

From the formulas above, we have:
$$\begin{aligned}
f(i,s_j) &= \log \sum_{y' \in Y_{i-1},\, y_i = s_j} \exp\big[S_i(x,y',s_j)\big] \\
&= \log \sum_{t=1}^{T} \sum_{y' \in Y_{i-2},\, y_{i-1} = s_t,\, y_i = s_j} \exp\big[S_i(x,y',s_t,s_j)\big] \\
&= \log \sum_{t=1}^{T} \sum_{y' \in Y_{i-2},\, y_{i-1} = s_t} \exp\big[S_{i-1}(x,y',s_t) + A_{s_t,s_j} + P_{x_i,s_j}\big] \\
&= \log \sum_{t=1}^{T} \exp\big[A_{s_t,s_j} + P_{x_i,s_j}\big] \cdot \sum_{y' \in Y_{i-2},\, y_{i-1} = s_t} \exp\big[S_{i-1}(x,y',s_t)\big] \\
&= \log \sum_{t=1}^{T} \exp\big[A_{s_t,s_j} + P_{x_i,s_j}\big] \cdot \exp\Big[\log \sum_{y' \in Y_{i-2},\, y_{i-1} = s_t} \exp\big[S_{i-1}(x,y',s_t)\big]\Big] \\
&= \log \sum_{t=1}^{T} \exp\big[A_{s_t,s_j} + P_{x_i,s_j}\big] \cdot \exp\big[f(i-1,s_t)\big] \\
&= \log \sum_{t=1}^{T} \exp\big[f(i-1,s_t) + A_{s_t,s_j} + P_{x_i,s_j}\big]
\end{aligned}$$

Therefore, the log sum at the current time step can be obtained from the log sum at the previous time step; the full normalizer is then $\log \sum_{t=1}^{T} \exp[f(n,s_t)]$, which costs $O(nT^2)$ instead of enumerating all $T^n$ sequences.
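
This recursion maps directly onto torch.logsumexp. Continuing the toy sketch above (reusing A, P, seq_len, and all_scores from it), the forward recursion reproduces the brute-force log-sum:

# f[j] = f(i, s_j): log-sum of scores of all prefixes ending in state s_j at step i.
f = P[0].clone()                                # base case: emission scores at position 0
for i in range(1, seq_len):
    # f(i, s_j) = log sum_t exp[f(i-1, s_t) + A_{s_t, s_j} + P_{x_i, s_j}]
    f = torch.logsumexp(f.unsqueeze(1) + A + P[i].unsqueeze(0), dim=0)
log_sum = torch.logsumexp(f, dim=0)             # log-sum over the final states

# Agrees with brute-force enumeration over all T^n sequences.
assert torch.allclose(log_sum, torch.logsumexp(all_scores, dim=0))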

3. Implementation

Combining the formulas above, we implement a CRF class with PyTorch, including the forward pass and the decoding procedure.

import torch
import torch.nn as nn

from typing import Optional, List


class CRF(nn.Module):
    """
    This module implements a conditional random field.
    The forward computation of this class computes the log likelihood of the given sequence of tags and emission score tensor.
    This class also has 'decode' method which finds the best tag sequence given an emission score tensor using 'Viterbi algorithm'.
    """

    def __init__(self, num_tags: int, batch_first: bool = False):
        """
        Args:
        :param num_tags: Number of tags.
        :param batch_first: Whether the first dimension corresponds to the size of a minibatch.
        Attributes:
            start_transitions: Start transition score tensor of size (num_tags, ).
            end_transitions: End transition score tensor of size (num_tags, ).
            transitions: Transition matrix tensor of size (num_tags, num_tags).
        """
        if num_tags <= 0:
            raise ValueError(f'num_tags must be positive, got {num_tags}.')
        super(CRF, self).__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        # initialize transition parameters
        self.reset_parameters()

    def __repr__(self):
        # Description of the class.
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def reset_parameters(self):
        """
        Initialize the transition parameters.
        The parameters are initialized randomly from a uniform distribution between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum'
    ) -> torch.Tensor:
        """
        Args:
        :param emissions: Emission score tensor of size (batch_sz, seq_len, num_tags) if batch_first else (seq_len, batch_sz, num_tags)
        :param tags: Sequence of tags tensor of size (batch_sz, seq_len) if batch_first else (seq_len, batch_sz)
        :param mask: Mask tensor of size (batch_sz, seq_len) if batch_first else (seq_len, batch_sz)
        :param reduction: Specifies the reduction to apply to the output: 'none', 'sum', 'mean' or 'token_mean'.
        Returns:
        :return The log-likelihood score of shape (batch_sz,) if reduction='none', otherwise ().
        """
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'Invalid Reduction: {reduction}')

        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # (batch_sz, )
        numerator = self._compute_score(emissions, tags, mask)
        # (batch_sz, )
        denominator = self._compute_log_sum_score(emissions, mask)
        # (batch_sz, )
        llh = numerator - denominator  # log-likelihood

        if reduction == 'none':  # not apply reduction
            return llh
        if reduction == 'sum':  # summed over batches
            return llh.sum()
        if reduction == 'mean':  # average over batches
            return llh.mean()
        if reduction == 'token_mean':  # average over tokens
            return llh.sum() / mask.float().sum()

    def decode(self, emissions: torch.Tensor, mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """
        Find the most likely tag sequence using Viterbi algorithm
        Args:
        :param emissions: (seq_len, batch_size, num_tags)
        :param mask: (seq_len, batch_size)
        Return:
        :return List of list containing the best tag sequence for each batch.
        """
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)
        return self._viterbi_decode(emissions, mask)

    def _compute_score(self, emissions, tags, mask):
        """
        :param emissions: (seq_len, batch_size, num_tags)
        :param tags: (seq_len, batch_size)
        :param mask: (seq_len, batch_size)
        """
        seq_len, batch_sz = tags.shape
        batch_idx = torch.arange(batch_sz, device=tags.device)
        mask = mask.float()
        # Start transition score plus first emission score
        # (batch_sz,)
        score = self.start_transitions[tags[0]] + emissions[0, batch_idx, tags[0]]
        for t in range(1, seq_len):
            # Transition score to next tag, zeroed out at padding positions
            score += self.transitions[tags[t - 1], tags[t]] * mask[t]
            # Emission score for next tag, zeroed out at padding positions
            score += emissions[t, batch_idx, tags[t]] * mask[t]
        # Index of the last valid position of each sequence, shape (batch_sz,)
        seq_end_idx = mask.long().sum(dim=0) - 1
        # End transition score from the last valid tag of each sequence
        score += self.end_transitions[tags[seq_end_idx, batch_idx]]
        # (batch_sz,)
        return score

    def _compute_log_sum_score(self, emissions, mask):
        """
        :param emissions: (seq_len, batch_size, num_tags)
        :param mask: (seq_len, batch_size)
        """
        seq_len = mask.shape[0]
        # (batch_sz, num_tags)
        score = self.start_transitions + emissions[0]
        # accumulate the log-sum of the scores of all paths ending in each tag, timestep by timestep
        for t in range(1, seq_len):
            # (batch_sz, num_tags) -> (batch_sz, num_tags, 1)
            broad_score = score.unsqueeze(2)
            # (batch_sz, num_tags) -> (batch_sz, 1, num_tags)
            broad_emit = emissions[t].unsqueeze(1)
            # f(i, s_j) = log-sum-exp[f(i-1,s_t)+trans(s_t,s_j)+emit(s_j)]
            # (batch_sz, num_tags, num_tags)
            next_score = broad_score + self.transitions.unsqueeze(0) + broad_emit
            # (batch_sz, num_tags, num_tags) -> (batch_sz, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)
            # keep the previous score at padding positions
            score = torch.where(mask[t].unsqueeze(1).bool(), next_score, score)
        # (batch_sz, num_tags)
        score += self.end_transitions
        # (batch_sz, num_tags) -> (batch_sz,)
        score = torch.logsumexp(score, 1)
        return score

    def _viterbi_decode(self, emissions, mask) -> List[List[int]]:
        """
        :param emissions: (seq_len, batch_size, num_tags)
        :param mask: (seq_len, batch_size)
        :return List of lists containing the best tag sequence for each batch entry.
        """
        seq_len, batch_sz = mask.shape
        history = []
        # (batch_sz, num_tags)
        score = self.start_transitions + emissions[0]
        # track the score and backpointers of the best path ending in each tag at each timestep
        for t in range(1, seq_len):
            # (batch_sz, num_tags) -> (batch_sz, num_tags, 1)
            broad_score = score.unsqueeze(2)
            # (batch_sz, num_tags) -> (batch_sz, 1, num_tags)
            broad_emit = emissions[t].unsqueeze(1)
            # (batch_sz, num_tags, num_tags)
            next_score = broad_score + self.transitions.unsqueeze(0) + broad_emit
            # (batch_sz, num_tags, num_tags) -> (batch_sz, num_tags), (batch_sz, num_tags)
            next_score, idx = torch.max(next_score, dim=1)
            score = torch.where(mask[t].unsqueeze(1).bool(), next_score, score)
            history.append(idx)
        # (batch_sz, num_tags)
        score += self.end_transitions
        # Index of the last valid position of each sequence, shape (batch_sz,)
        seq_end_idx = mask.long().sum(dim=0) - 1
        best_tag_list = []
        # trace the best path back for each sequence in the batch
        for i in range(batch_sz):
            # best final tag given the end transition scores
            best_last_idx = score[i].argmax(dim=0).item()
            best_tags = [best_last_idx]
            # follow the backpointers from the last valid position back to the start
            for idx in reversed(history[:seq_end_idx[i]]):
                best_last_idx = idx[i][best_last_idx].item()
                best_tags.append(best_last_idx)
            # the path was collected backwards, so reverse it
            best_tags.reverse()
            best_tag_list.append(best_tags)
        return best_tag_list
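
A minimal usage sketch (shapes and values made up for illustration): build a CRF over five tags, use the negative log-likelihood as a training loss, and decode the best tag sequences.

num_tags, seq_len, batch_sz = 5, 3, 2
crf = CRF(num_tags)

# Made-up emission scores; in practice they come from an upstream encoder such as a BiLSTM.
emissions = torch.randn(seq_len, batch_sz, num_tags)
tags = torch.randint(num_tags, (seq_len, batch_sz))

# forward() returns the log-likelihood, so the training loss is its negation.
loss = -crf(emissions, tags, reduction='mean')
loss.backward()

# Viterbi decoding returns one best tag list per batch entry.
print(crf.decode(emissions))  # e.g. [[3, 1, 4], [0, 2, 2]]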
