Tensorflow2.+版本使用CRF
最近因为GPU原因从tf1.x转到tf2.x,许多方法都用不了了,其中就有CRF,特此记录一下
之前CRF一直都存在于tf1.x中的contrib中,不过tf2.x连contrib都没了,CRF也转到了tensorflow_addons下
调用方式
import tensorflow_addons或者pip list查看一下
没有的话就pip安装即可:
pip install tensorflow_addons
安装好了直接调用即可:
import tensorflow as tf
import tensorflow_addons as tfa
# log_likelikelihood,transition_params = tf.contrib.crf.crf_log_likehood(input,tags,seqlen)
# predict, viterbi_score = tf.contrib.crf.crf_decode(input, transition_params, seqlen)
log_likelikelihood,transition_params = tf.text.crf.crf_log_likehood(input,tags,seqlen)
predict, viterbi_score = tfa.text.crf.crf_decode(input,transition_params,seqlen)
现在应该可以成功调用了
此外,如果还有其他方法在tf2中没法用了,可以尝试tf.compat.v1的方式来调用该方法
例如:
tf2中没法调用tf.gfile.Exists()方法,但是可以通过:tf.compat.v1.gfile.Exists()来使用
tf.compat.v1模块是确保高版本的TF支持低版本的TF的,还是建议转换为TF2.0中的新方法:
tf.gfile.Exists()方法在TF2.0中为: tf.io.gfile.exists()