CRF的三个核心函数

https://www.jianshu.com/p/bddf0641970c

CRF的实现-tensorflow版本
不分享的知识毫无意义
0.076
2020.03.14 17:33:26
字数 1,039阅读 588
0.前言

CRF的原理已经够难理解了,需要解决的问题主要包括三大块:

概率计算问题,前向—后向算法,是一个递推公式,这个和hmm是一样的。
学习问题,这是判别式模型必须要有的东西,得训练参数,常用的方法是改进的迭代尺度法,拟牛顿法。
预测问题,维特比算法,这是个动态规划方法,hmm和crf都会用到。这个好像废话,目的都是为了预测,当然要用。
数学公式一大堆,什么向量形式,矩阵形式,着实难以理解,但是关于事先就很简单了,哈哈哈。下面分别基于tensorflow、keras、pytorch来实现CRF。

1.tensorflow实现

tensorflow1.0可真难用啊,吐槽一下,还是2.0好用。举个小例子,你定义一个op操作以后,即使是简单的x1+x2,要想看输出,还得print(sess.run()),2.0就不用了,大家赶紧上手2.0。不过这里还是基于tensorflow1.0实现的。
tensorflow实现crf就三个函数,crf_log_likelihood、viterbi_decode、crf_decode,他们都在tf.contrib.crf这个API里,搞懂这三个函数,不管事BiLSTM+CRF还是BERT+BiLSTM+CRF你都游刃有余了。

tf.contrib.crf.crf_log_likelihood
crf_log_likelihood(inputs,tag_indices,sequence_lengths,transition_params=None)
通俗理解,这是CRF的训练函数。
首先来看输入:
(1)inputs,维度为[batch_size, max_seq_len, num_tags],一般是LSTM的输出,要转换成这个要求的维度,再到CRF里边训练。
batch_size是批次训练样本量,好理解,不解释。
maxseq_len是输入文本的长度,相当于LSTM里的input_dim,就是输入几个单词。
num_tags是可供选择的单词个数,比如你觉得这个位置有5个可能的单词,那这个就是5。
(2)tag_indices,维度为[batch_size, max_seq_len]。
具体的和inputs一样,只不过这个是真实的标签,也就是相应位置对应的真实y值。
(3)sequence_lengths,维度为 [batch_size]。
表示的是每一个序列的长度,是一维的,相当于max_sql_len,可以用np.full这个函数实现。
(4)transition_params,维度为[num_tags, num_tags],是转移矩阵,要是事先没有就训练一个。
然后来看输出:
(1)log_likelihood,标量,还记得吧,CRF训练参数用的是极大似然估计,这个值取负数就是交叉熵损失。
(2)transition_params,维度为[num_tags, num_tags],转移矩阵,这个是我们预测要用到的。
tf.contrib.crf.viterbi_decode
viterbi_decode(score,transition_params)
这个函数返回最好序列的标签,用的场景不是特别多。
输入:
(1)score,维度为[seq_len, num_tags],参数的意思就不解释了,具体看上边的说法,这就是一个得分。
(2)transition_params,维度为[num_tags, num_tags],上边训练输出的转移矩阵。
输出:
(1)viterbi,维度[seq_len],保留了每一步对应得分值最高的索引。
(2)viterbi_score,维度为[sel_len],这个是维特比的具体得分。
tf.contrib.crf.viterbi_decode
crf_decode(potentials,transition_params,sequence_length)
这个函数和上边那个差不多,但是很常用。
输入:
(1)potentials,维度为[batch_size, max_seq_len, num_tags],这个是满足条件的一个输入,可以使输入和一个权重矩阵乘后的结果。
(2)transition_params,转义矩阵不多说。
(3)sequence_length,和上边一样,输入长度构成的一维矩阵。
输出:
(1)decode_tags,维度为[batch_size, max_seq_len] ,是一个最好序列的标记。
(2)best_score,维度为[batch_size],每个序列的最好得分。
来看一个小例子,这个例子是一个随机的数字输入,对应一个只含0,1两个状态的目标矩阵,然后根据输入预测输出。代码如下:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
Timestep = 15#输入的总长度,可以理解为15个rnn cell
Batchsize = 1#一次就输入一个
Inputsize = 1
LR = 0.5
num_tags = 2
#定义batch输出
def get_batch():
    xs = np.array([[2, 3, 4, 5, 5, 5, 1, 5, 3, 2, 5, 5, 5, 3, 5]])
    res = np.array([[0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1]])
    return [xs[:, :, np.newaxis], res]
# xs, res = get_batch()
# print(xs)
# xs变成三维的 res还是二维的
class crf:
    def __init__(self, time_steps, input_size, num_tags, batch_size):
        self.time_steps = time_steps
        self.input_size = input_size
        self.num_tags = num_tags
        self.batch_size = batch_size
        self.xs = tf.placeholder(tf.float32, [None, self.time_steps, self.input_size], name='xs')
        self.res = tf.placeholder(tf.int32, [self.batch_size, self.time_steps], name='res')#为什么和xs的定义模式不一样
        weights = tf.get_variable('weights', [self.input_size, self.num_tags])
        matricized_xs = tf.reshape(self.xs, [-1, self.input_size])
        matricized_unary_scores = tf.matmul(matricized_xs, weights)
        unary_scores = tf.reshape(matricized_unary_scores, [self.batch_size, self.time_steps, self.num_tags])
        sequence_len = np.full(self.batch_size, self.time_steps, dtype=np.int32)
        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(unary_scores, self.res, sequence_len)
        self.pred, viterbiscore = tf.contrib.crf.crf_decode(unary_scores, transition_params, sequence_len)
        self.loss = tf.reduce_mean(-log_likelihood)
        self.train_op = tf.train.AdamOptimizer(LR).minimize(self.loss)


if __name__ == '__main__':
    model = crf(Timestep, Inputsize, num_tags, Batchsize)
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    plt.ion()#动态曲线
    plt.show()
    for i in range(150):
        xs, res = get_batch()
        feed_dict = {model.xs: xs,
                     model.res: res}
        _, cost, pred = sess.run([model.train_op, model.loss, model.pred],
                                 feed_dict=feed_dict)#只有placeholder才可以feed
        x = xs.reshape(-1, 1)
        r = res.reshape(-1, 1)
        p = pred.reshape(-1, 1)
        x = range(len(x))
        plt.clf()
        plt.plot(x, r, 'r', x, p, 'g')
        plt.ylim(-1.2, 1.2)
        plt.draw()
        plt.pause(0.3)
        if i % 20 == 0:
            print('cost:', round(cost, 4))
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在现有省、市港口信息化系统进行有效整合基础上,借鉴新 一代的感知-传输-应用技术体系,实现对码头、船舶、货物、重 大危险源、危险货物装卸过程、航管航运等管理要素的全面感知、 有效传输和按需定制服务,为行政管理人员和相关单位及人员提 供高效的管理辅助,并为公众提供便捷、实时的水运信息服务。 建立信息整合、交换和共享机制,建立健全信息化管理支撑 体系,以及相关标准规范和安全保障体系;按照“绿色循环低碳” 交通的要求,搭建高效、弹性、高可扩展性的基于虚拟技术的信 息基础设施,支撑信息平台低成本运行,实现电子政务建设和服务模式的转变。 实现以感知港口、感知船舶、感知货物为手段,以港航智能 分析、科学决策、高效服务为目的和核心理念,构建“智慧港口”的发展体系。 结合“智慧港口”相关业务工作特点及信息化现状的实际情况,本项目具体建设目标为: 一张图(即GIS 地理信息服务平台) 在建设岸线、港口、港区、码头、泊位等港口主要基础资源图层上,建设GIS 地理信息服务平台,在此基础上依次接入和叠加规划建设、经营、安全、航管等相关业务应用专题数据,并叠 加动态数据,如 AIS/GPS/移动平台数据,逐步建成航运管理处 "一张图"。系统支持扩展框架,方便未来更多应用资源的逐步整合。 现场执法监管系统 基于港口(航管)执法基地建设规划,依托统一的执法区域 管理和数字化监控平台,通过加强对辖区内的监控,结合移动平 台,形成完整的多维路径和信息追踪,真正做到问题能发现、事态能控制、突发问题能解决。 运行监测和辅助决策系统 对区域港口与航运业务日常所需填报及监测的数据经过科 学归纳及分析,采用统一平台,消除重复的填报数据,进行企业 输入和自动录入,并进行系统智能判断,避免填入错误的数据, 输入的数据经过智能组合,自动生成各业务部门所需的数据报 表,包括字段、格式,都可以根据需要进行定制,同时满足扩展 性需要,当有新的业务监测数据表需要产生时,系统将分析新的 需求,将所需字段融合进入日常监测和决策辅助平台的统一平台中,并生成新的所需业务数据监测及决策表。 综合指挥调度系统 建设以港航应急指挥中心为枢纽,以各级管理部门和经营港 口企业为节点,快速调度、信息共享的通信网络,满足应急处置中所需要的信息采集、指挥调度和过程监控等通信保障任务。 设计思路 根据项目的建设目标和“智慧港口”信息化平台的总体框架、 设计思路、建设内容及保障措施,围绕业务协同、信息共享,充 分考虑各航运(港政)管理处内部管理的需求,平台采用“全面 整合、重点补充、突出共享、逐步完善”策略,加强重点区域或 运输通道交通基础设施、运载装备、运行环境的监测监控,完善 运行协调、应急处置通信手段,促进跨区域、跨部门信息共享和业务协同。 以“统筹协调、综合监管”为目标,以提供综合、动态、实 时、准确、实用的安全畅通和应急数据共享为核心,围绕“保畅通、抓安全、促应急"等实际需求来建设智慧港口信息化平台。 系统充分整合和利用航运管理处现有相关信息资源,以地理 信息技术、网络视频技术、互联网技术、移动通信技术、云计算 技术为支撑,结合航运管理处专网与行业数据交换平台,构建航 运管理处与各部门之间智慧、畅通、安全、高效、绿色低碳的智 慧港口信息化平台。 系统充分考虑航运管理处安全法规及安全职责今后的变化 与发展趋势,应用目前主流的、成熟的应用技术,内联外引,优势互补,使系统建设具备良好的开放性、扩展性、可维护性。
提供的源码资源涵盖了安卓应用、小程序、Python应用和Java应用等多个领域,每个领域都包含了丰富的实例和项目。这些源码都是基于各自平台的最新技术和标准编写,确保了在对应环境下能够无缝运行。同时,源码中配备了详细的注释和文档,帮助用户快速理解代码结构和实现逻辑。 适用人群: 这些源码资源特别适合大学生群体。无论你是计算机相关专业的学生,还是对其他领域编程感兴趣的学生,这些资源都能为你提供宝贵的学习和实践机会。通过学习和运行这些源码,你可以掌握各平台开发的基础知识,提升编程能力和项目实战经验。 使用场景及目标: 在学习阶段,你可以利用这些源码资源进行课程实践、课外项目或毕业设计。通过分析和运行源码,你将深入了解各平台开发的技术细节和最佳实践,逐步培养起自己的项目开发和问题解决能力。此外,在求职或创业过程中,具备跨平台开发能力的大学生将更具竞争力。 其他说明: 为了确保源码资源的可运行性和易用性,特别注意了以下几点:首先,每份源码都提供了详细的运行环境和依赖说明,确保用户能够轻松搭建起开发环境;其次,源码中的注释和文档都非常完善,方便用户快速上手和理解代码;最后,我会定期更新这些源码资源,以适应各平台技术的最新发展和市场需求。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值