import torch
import torch.nn as nn
from typing import List, Optional


class CRF(nn.Module):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the log likelihood of the given sequence of tags and
    emission score tensor. This class also has a `~CRF.decode` method, which finds
    the best tag sequence given an emission score tensor using the `Viterbi algorithm`_.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.

    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.
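
    A minimal usage sketch; the tag count, sequence length, batch size, and the
    decoded output below are arbitrary illustrative values, not part of the API:

    Example::

        >>> import torch
        >>> num_tags = 5
        >>> model = CRF(num_tags)
        >>> emissions = torch.randn(3, 2, num_tags)  # (seq_length, batch_size, num_tags)
        >>> tags = torch.randint(num_tags, (3, 2))   # (seq_length, batch_size)
        >>> loss = -model(emissions, tags)           # negative log likelihood for training
        >>> model.decode(emissions)  # doctest: +SKIP
        [[1, 3, 3], [0, 1, 0]]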

    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.

    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
"""def__init__(self, num_tags:int, batch_first:bool=False)->None:if num_tags <=0:raise ValueError(f'invalid number of tags: {num_tags}')super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
            reduction: Specifies the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.

        Returns:
            `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
            reduction is ``none``, ``()`` otherwise.
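
        A short sketch of the ``reduction`` options, reusing the hypothetical
        ``model``, ``emissions``, and ``tags`` from the class-level example:

        Example::

            >>> model(emissions, tags, reduction='none').shape  # doctest: +SKIP
            torch.Size([2])
            >>> model(emissions, tags, reduction='sum').dim()   # doctest: +SKIP
            0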
"""
self._validate(emissions, tags=tags, mask=mask)if reduction notin('none','sum','mean','token_mean'):raise ValueError(f'invalid reduction: {reduction}')if mask isNone:
mask = torch.ones_like(tags, dtype=torch.uint8)if self.batch_first:
emissions = emissions.transpose(0,1)
tags = tags.transpose(0,1)
mask = mask.transpose(0,1)# shape: (batch_size,)
numerator = self._compute_score(emissions, tags, mask)# shape: (batch_size,)
denominator = self._compute_normalizer(emissions, mask)# shape: (batch_size,)
llh = numerator - denominator
if reduction =='none':return llh
if reduction =='sum':return llh.sum()if reduction =='mean':return llh.mean()assert reduction =='token_mean'return llh.sum()/ mask.float().sum()defdecode(self, emissions: torch.Tensor,
mask: Optional[torch.ByteTensor]=None)-> List[List[int]]:"""Find the most likely tag sequence using Viterbi algorithm.
        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.

        Returns:
            List of lists containing the best tag sequence for each batch.
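
        A minimal sketch, reusing the hypothetical ``model`` and ``emissions``
        from the class-level example; the mask marks the second sequence as
        having length 2, so its decoded list is shorter (values illustrative):

        Example::

            >>> mask = torch.tensor([[1, 1], [1, 1], [1, 0]], dtype=torch.uint8)
            >>> model.decode(emissions, mask=mask)  # doctest: +SKIP
            [[1, 3, 3], [0, 1]]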
"""
self._validate(emissions, mask=mask)if mask isNone:
mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)if self.batch_first:
emissions = emissions.transpose(0,1)
mask = mask.transpose(0,1)return self._viterbi_decode(emissions, mask)def_validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.ByteTensor] = None) -> None:
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1);
            # the mask is cast to bool so torch.where accepts it on recent PyTorch
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tags
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score; the mask is cast to
            # bool so torch.where accepts it on recent PyTorch
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list
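

# The block below is a minimal, self-contained usage sketch, not part of the
# library API: the tag count, sequence length, batch size, and the random
# emissions/tags are arbitrary illustrative values.
if __name__ == '__main__':
    torch.manual_seed(0)

    num_tags, seq_length, batch_size = 5, 3, 2
    model = CRF(num_tags)

    # Random emission scores and gold tags in (seq_length, batch_size) layout,
    # since batch_first defaults to False.
    emissions = torch.randn(seq_length, batch_size, num_tags)
    tags = torch.randint(num_tags, (seq_length, batch_size))

    # Training objective: the negative conditional log likelihood.
    loss = -model(emissions, tags)
    loss.backward()
    print(f'negative log likelihood: {loss.item():.4f}')

    # Inference: Viterbi decoding returns one list of tag indices per batch entry.
    print('best tag sequences:', model.decode(emissions))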