BLEU代码实现

最近学习Seq2Seq内容,碰到模型评价指标计算,编写一下BLEU加深印象,指标介绍参见

详解机器翻译任务中的BLEU_Iareges的博客-CSDN博客

使用DP优化计算两个序列相同n元语法的部分

def max_common_seq(s1, s2, m, n, k):
    # 求出s1, s2, 的k元语法下最多匹配词组数
    m = m - k + 1
    n = n - k + 1
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m):
        for j in range(n):
            if s1[i:i+k] == s2[j:j+k]:
                dp[i+1][j+1] = dp[i][j] + 1
            else:
                dp[i+1][j+1] = max(dp[i][j+1], dp[i+1][j])
    return dp[-1][-1]


def BLEU(label, pred, k):
    #  计算label 与 pred 在最大k元语法下的BLEU
    assert k > 0
    assert len(pred) >= k
    m = len(label)
    n = len(pred)
    score = exp(min(0, 1 - m/n))
    for p in range(1, k+1):
        score *= (max_common_seq(label, pred, m, n, p) / (n - p + 1)) ** (1 / (2 ** p))
    return score

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
Bleu(Bilingual Evaluation Understudy)是一种用于评估机器翻译结果的指标,它是一种基于n-gram的精度指标。Bleu指标的值在0和1之间,值越高表示机器翻译结果与人工翻译结果越接近。下面是一个基于MATLAB实现Bleu代码示例: ```matlab function bleu_score = BLEU(candidate, reference, n) % candidate: 机器翻译结果,字符串形式 % reference: 参考翻译结果,字符串形式或字符串数组形式 % n: n-gram的大小,取值范围为1-4 % 将candidate和reference转换为字符串数组形式 if ischar(reference) reference = {reference}; end % 计算每个n-gram的精度权重 weights = ones(n, 1) ./ n; % 计算candidate和reference的n-gram数值 candidate_ngram_counts = count_ngrams(candidate, n); reference_ngram_counts = count_ngrams(reference, n); % 初始化精度值和翻译长度 precision_scores = zeros(size(candidate_ngram_counts)); candidate_length = length(strsplit(candidate)); % 计算每个n-gram的精度值 for i = 1:length(precision_scores) precision_scores(i) = min(candidate_ngram_counts(i), max(reference_ngram_counts(i))); end % 计算geometric mean geometric_mean = exp(mean(log(precision_scores(precision_scores~=0)))); % 计算Bleu分数 bleu_score = weights' * geometric_mean * exp(min(0, 1 - length(reference)/candidate_length)); end function ngram_counts = count_ngrams(strs, n) % 计算输入字符串数组中每个n-gram的数量 ngram_counts = zeros(1, n^2); for i = 1:length(strs) words = strsplit(strs{i}); for j = 1:length(words)-n+1 ngram = strjoin(words(j:j+n-1)); index = hash(ngram); ngram_counts(index) = ngram_counts(index) + 1; end end end function index = hash(str) % 将字符串哈希为唯一索引 prime = 5; index = 0; for i = 1:length(str) index = index + double(str(i)) * prime^(i-1); end end ``` 上面的代码中,count_ngrams函数用于计算输入字符串数组中每个n-gram的数量,hash函数用于将字符串哈希为唯一索引。BLEU函数是计算Bleu分数的主要函数,它使用了precision_scores数组来存储每个n-gram的精度值,并使用geometric_mean计算geometric mean。最后,它根据输入的参考翻译结果和机器翻译结果计算Bleu分数。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值