GridCRF 是 pystruct 中针对网格图(grid graph)的 CRF 实现,但官方代码只能运行在 Python 3.6 上,
且很难将其融入深度学习框架当中。
网上几乎没有相关资料;没有办法,我通过不断调试源码,在这里对其进行总结。
X, Y = generate_crosses_explicit(n_samples=10, noise=12)
# Generate a toy dataset with pystruct's built-in dataset generator.
crf = GridCRF(neighborhood=8)
# Declare the grid-CRF model. This only defines the graph structure;
# most of the actual machinery is driven by the OneSlackSSVM learner below.
# neighborhood=8 means each grid node is connected to all eight of its
# surrounding neighbors.
clf = ssvm.OneSlackSSVM(model=crf, C=100, inference_cache=100,
tol=.1)
# Only here is a complete, trainable grid-CRF assembled.
clf.fit(X, Y)
# This runs the iterative training of the grid-CRF.
训练的主要函数位于 one_slack_ssvm 模块中,关键调用如下:
joint_feature_gt = self.model.batch_joint_feature(X, Y)
#X(10,9,9,3) Y(10,9,9) joint_feature_gt (15)
joint_feature_gt 的维度为 15 = 9 + 6:9 应来自一元(状态)特征(3 个标签 × 3 个输入特征),6 应来自对称成对(转移)特征(3×3 矩阵的上三角,含对角线)。
Y_hat, djoint_feature, loss_mean = self._find_new_constraint(
X, Y, joint_feature_gt, constraints)
#Y_hat (10,9,9,3), djoint_feature(15), loss_mean(float64),constraints(15)
def _update_cache(self, X, Y, Y_hat):
    """Add newly inferred labelings to the per-sample constraint cache.

    For each training sample, stores a tuple of
    (joint_feature(x, y_hat), loss(y, y_hat), y_hat) unless an equal
    labeling is already cached, evicting the oldest entry when the
    cache exceeds ``self.inference_cache`` entries.
    """
    # NOTE(review): per the author's notes, in the traced run Y_hat held
    # state features of shape (10, 9, 9, 3) and transition features of
    # shape (10, 272, 9) -- confirm against the actual model.
    # Caching is disabled entirely.
    if self.inference_cache == 0:
        return
    # Lazily create one (initially empty) cache list per training sample.
    if getattr(self, "inference_cache_", None) is None:
        self.inference_cache_ = [[] for _ in Y_hat]
    for cache, x, y, y_hat in zip(self.inference_cache_, X, Y, Y_hat):
        # Skip labelings already stored for this sample.
        matches = [self.constraint_equal(y_hat, entry[2])
                   for entry in cache]
        if np.any(matches):
            continue
        # Evict the oldest entry once the cache is over capacity.
        if len(cache) > self.inference_cache:
            cache.pop(0)
        # joint_feature and loss were both computed before, but were
        # summed up immediately, so they are recomputed here. This is
        # slightly less efficient in the caching case, but when caching
        # is on, inference is far more expensive and this cost is
        # negligible.
        cache.append((self.model.joint_feature(x, y_hat),
                      self.model.loss(y, y_hat), y_hat))