PartitionExplainer in shap: how it works, plus debugging notes on the official demo

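I. PartitionExplainer

shap ships several explainers; the one paired with transformer NLP tasks is the Partition explainer.
Its idea: compute Shapley values recursively through a hierarchy of features. The hierarchy defines the feature coalitions, and the result is the Owen value from game theory.

The source code involves too many classes and conditional branches and is hard to read, so the algorithm is described concisely here.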

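Coalition description

  • First, introduce a five-tuple struct, CoalitionInfo.
CoalitionInfo = (m00, f00, f11, ind, weight)
1. m00: the mask used when this coalition does not participate
2. f00: the model output when this coalition does not participate
3. f11: the model output when this coalition participates
4. ind: the identifier of this coalition
5. weight: the weight of this coalition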

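Algorithm description

1. Put the root node into the queue.
2. Pop a coalition from the queue; call it co_i.
3. If co_i is a leaf node, update `values[i] += (f11 - f00) * weight`; otherwise, build the model inputs fin for the two coalitions formed by its left and right subtrees.
4. Feed fin to the model to obtain the model output fout.
5. From fout, build the CoalitionInfo for the left and right subtree coalitions. Note that, because of context, these 2 coalitions expand into 4 CoalitionInfo entries. Put them into the queue.
6. Go back to step 2 until the queue is empty.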

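A binary-tree node has only two subtrees, so why does the number of entries double? Because the gain of a coalition can be measured in two ways, as shown below.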
The contribution of l_child, measured in two ways (without and with r_child present):
lift(l_child) = f(A + B + l_child) - f(A + B)
lift(l_child) = f(A + B + l_child + r_child) - f(A + B + r_child)
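
The queue-based steps and the two ways of measuring a child's lift can be sketched on a toy problem. This is a minimal illustration, not shap's implementation: `owen_values`, `toy_model`, and the `children` dictionary are hypothetical names, the model is a plain Python function over a boolean mask, and there is no batching or evaluation budget.

```python
from collections import deque

import numpy as np


def owen_values(f, children, root, n_leaves):
    """Owen-style attribution over a binary hierarchy (illustrative sketch).

    f        : callable taking a boolean mask of shape [n_leaves], returning a float
    children : dict mapping internal node id -> (left id, right id);
               leaves are numbered 0 .. n_leaves-1
    """
    def leaf_mask(node):
        # Boolean mask of all leaves under `node`.
        mask = np.zeros(n_leaves, dtype=bool)
        stack = [node]
        while stack:
            k = stack.pop()
            if k < n_leaves:
                mask[k] = True
            else:
                stack.extend(children[k])
        return mask

    values = np.zeros(n_leaves)
    m00 = np.zeros(n_leaves, dtype=bool)
    # Five-tuple: (m00, f00, f11, ind, weight)
    queue = deque([(m00, f(m00), f(leaf_mask(root)), root, 1.0)])
    while queue:
        m00, f00, f11, ind, weight = queue.popleft()
        if ind < n_leaves:
            values[ind] += (f11 - f00) * weight
            continue
        lind, rind = children[ind]
        m10 = m00 | leaf_mask(lind)   # context + left child
        m01 = m00 | leaf_mask(rind)   # context + right child
        f10, f01 = f(m10), f(m01)
        w = weight / 2                # each child is measured in two contexts
        queue.append((m00, f00, f10, lind, w))  # left, sibling absent
        queue.append((m00, f00, f01, rind, w))  # right, sibling absent
        queue.append((m01, f01, f11, lind, w))  # left, sibling present
        queue.append((m10, f10, f11, rind, w))  # right, sibling present
    return values


def toy_model(mask):
    x = mask.astype(float)
    # Features 0 and 1 interact; features 2 and 3 are additive.
    return 2.0 * x[0] * x[1] + x[2] + 3.0 * x[3]


children = {4: (0, 1), 5: (2, 3), 6: (4, 5)}
values = owen_values(toy_model, children, root=6, n_leaves=4)
print(values)
```

With this toy model the attributions come out as 1, 1, 1, 3: the interaction term is split evenly inside its group, and the total equals f(all) - f(none) = 6.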

Time complexity estimate

Without loss of generality, let the hierarchy be a balanced binary tree.
A balanced binary tree with $n$ leaf nodes has height $h=\log_2 n$.
Under the algorithm above, the number of samples sent to the model per level forms a geometric sequence with common ratio $q=4$. By the partial-sum formula $S_k=a_1\cdot\frac{1-q^k}{1-q}$, with the tree height $h$ as the number of terms $k$ and $a_1=1$,
we get $S_h=\frac{1-4^{\log_2 n}}{1-4}$. Transforming the power in the numerator:
$(2^2)^{\log_2 n}=2^{2\log_2 n}=2^{\log_2 n^2}=n^2$
So the time complexity is $O(n^2)$.
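
The quadratic growth can be checked by counting directly. The sketch below assumes a perfectly balanced tree whose leaf count is a power of two, and it counts 2 model inputs per expanded queue entry, so the constant differs from the $a_1=1$ simplification above while the order stays $O(n^2)$. `count_evals` is a hypothetical helper name.

```python
def count_evals(n_leaves):
    """Model inputs needed by the unbudgeted algorithm on a balanced tree."""
    entries = 1      # queue entries at the current level (root only)
    evals = 0
    size = n_leaves  # leaves under each node at the current level
    while size > 1:
        evals += 2 * entries  # every internal entry needs m10 and m01
        entries *= 4          # each entry expands into 4 CoalitionInfo entries
        size //= 2
    return evals


for n in (2, 4, 8, 16):
    # Closed form of the geometric sum: 2 * (n^2 - 1) / 3
    print(n, count_evals(n), 2 * (n * n - 1) // 3)
```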

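Batch acceleration

Model prediction is faster overall in batch mode than one sample at a time, so the algorithm above can be sped up in engineering terms.

1. To bound the cost, maintain max_eval_cnt and stop as soon as it is exceeded; coalitions that were not fully expanded have their values spread directly onto their leaf nodes.
2. Run steps 2 and 3 above repeatedly as a unit, accumulating enough inputs to send to the model as one batch.
3. To favor coalitions with a clear contribution, replace the queue with a priority queue.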

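II. The official demo with a transformer

When a transformer is used for sentiment analysis in NLP, the task is simply POSITIVE/NEGATIVE binary classification.
To explain the model's prediction, we need to know how much each part of the sentence contributes to the sentiment.
An official example follows: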

  • Sentence: What a great movie! if you have no taste.
  • Chinese translation: 多棒的一个电影啊, 如果你没有品位的话.
  • After preprocessing: [['[CLS]', 'what', 'a', 'great', 'movie', '!', 'if', 'you', 'have', 'no', 'taste', '.', '[SEP]']]
  • Sentiment contributions: base_value=-0.690963, the POSITIVE output when every token is masked (see section 1.1). shap_values=[ 0. , 1.2201888 , 1.2201888 , 3.8936163 , 3.8936163 ,
    0.24495573, -0.16747759, -0.16747759, -0.16747759, -0.16747759,
    -0.16747759, -0.16747759, 0. ], an array of shape=[13,] giving the sentiment contribution of each token after preprocessing.
  • Visualization: tokens with the same score are grouped together.
  • Code
import transformers
import shap
import multiprocessing
import torch

# This has no effect; num_workers has to be set to 0 inside the shap library itself
if __name__ == '__main__':
    torch.multiprocessing.freeze_support()
    multiprocessing.freeze_support()

model_path = r'D:\model_repository\transformer\distilbert-base-uncased-finetuned-sst-2-english'
model = transformers.pipeline('sentiment-analysis', model=model_path, return_all_scores=True)

model_output = model(["What a great movie! ...if you have no taste."])
print('model_output', model_output, '==========', sep='\n')

explainer = shap.Explainer(model)
# Explanation:[1,13,2]
shap_values = explainer(["What a great movie! if you have no taste."])  # type: shap._explanation.Explanation
print('shap_values', shap_values, '==========', sep='\n')
# Explanation: [13,]
positive_shap_values = shap_values[0, :, "POSITIVE"]
print('positive_shap_values', positive_shap_values, '==========', sep='\n')

# visualize the first prediction's explanation for the POSITIVE output class
html_obj = shap.plots.text(positive_shap_values)
print(123)

"""
model_output
[[{'label': 'NEGATIVE', 'score': 0.00014734955038875341}, {'label': 'POSITIVE', 'score': 0.9998526573181152}]]
==========
shap_values
.values =
array([[[ 0.        ,  0.        ],
        [-1.22018296,  1.2201888 ],
        [-1.22018296,  1.2201888 ],
        [-3.89365431,  3.8936163 ],
        [-3.89365431,  3.8936163 ],
        [-0.24507864,  0.24495573],
        [ 0.16742747, -0.16747759],
        [ 0.16742747, -0.16747759],
        [ 0.16742747, -0.16747759],
        [ 0.16742747, -0.16747759],
        [ 0.16742747, -0.16747759],
        [ 0.16742747, -0.16747759],
        [ 0.        ,  0.        ]]])

.base_values =
array([[ 0.69096287, -0.690963  ]])

.data =
array([['', 'What ', 'a ', 'great ', 'movie', '! ', 'if ', 'you ',
        'have ', 'no ', 'taste', '.', '']], dtype='<U6')
==========
positive_shap_values
.values =
array([ 0.        ,  1.2201888 ,  1.2201888 ,  3.8936163 ,  3.8936163 ,
        0.24495573, -0.16747759, -0.16747759, -0.16747759, -0.16747759,
       -0.16747759, -0.16747759,  0.        ])

.base_values =
-0.6909630029188477

.data =
array(['', 'What ', 'a ', 'great ', 'movie', '! ', 'if ', 'you ', 'have ',
       'no ', 'taste', '.', ''], dtype='<U6')
"""

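The code that the snippet above depends on uses from IPython.core.display import display, HTML for HTML visualization. Outside Jupyter, however, the page is not rendered, and the console only shows the following:
Backend Qt5Agg is interactive backend. Turning interactive mode on. <IPython.core.display.HTML object>
Plain IPython does not work either. The only workaround is to set a breakpoint at site-packages\shap\plots\_text.py Line: 167, copy out the HTML string, and save it to a new html file to open by hand.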
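
Once the HTML string has been copied out at the breakpoint, saving and opening it can be scripted with the standard library alone. This is a small sketch under that assumption; `save_html` and the default file name are hypothetical, and nothing here calls into shap.

```python
import webbrowser
from pathlib import Path


def save_html(html_str, path="shap_text_plot.html", open_in_browser=False):
    """Wrap an HTML fragment in a minimal page and write it to disk."""
    out = Path(path)
    page = (
        "<html><head><meta charset='utf-8'></head><body>"
        + html_str
        + "</body></html>"
    )
    out.write_text(page, encoding="utf-8")
    if open_in_browser:
        # Opens the file with the system default browser.
        webbrowser.open(out.resolve().as_uri())
    return out


# Example with a stand-in fragment; in practice pass the string copied from the debugger.
saved = save_html("<b>What a great movie!</b>")
print(saved)
```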

III. Walkthrough of the relevant shap classes

The source reading mixes two versions, 0.39.0 and 0.40.0.

Explainer

  • shap.explainers._explainer.Explainer#__init__(self, model:TransformersPipeline , masker=None,...)
    The constructor. Initially only the model argument is passed in; internally it recognizes a transformer (<shap.models._transformers_pipeline.TransformersPipeline object at 0x00000144AFFAD280>) and calls itself again, passing model.inner_model.tokenizer (DistilBertTokenizerFast) as the masker argument.
    Besides calling itself twice, this constructor also pulls off a rather magical move: it reassigns self.__class__ to a chosen subclass.

  • explain_row(self, *row_arg)

  • shap.explainers._partition.Partition.owen(self, fm, f00, f11, max_evals, output_indexes, fixed_context, batch_size, silent)
    out_shape = 2*(M-1)+1
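
The self.__class__ reassignment noted for the Explainer constructor is a general Python pattern in which a base-class constructor turns the object into a subclass instance. The toy below only illustrates that pattern; these three classes and the `kind` attribute are made up for the example and do not mirror shap's actual selection logic.

```python
class Explainer:
    def __init__(self, model, algorithm="auto"):
        self.model = model
        # Only dispatch when constructed as the base class itself.
        if type(self) is Explainer:
            chosen = Exact if algorithm == "exact" else Partition
            self.__class__ = chosen                  # the object changes type here
            chosen.__init__(self, model, algorithm)  # run the subclass constructor


class Partition(Explainer):
    def __init__(self, model, algorithm="auto"):
        super().__init__(model, algorithm)  # type(self) is Partition now, no re-dispatch
        self.kind = "partition"


class Exact(Explainer):
    def __init__(self, model, algorithm="auto"):
        super().__init__(model, algorithm)
        self.kind = "exact"


explainer = Explainer(model=None)
print(type(explainer).__name__, explainer.kind)
```

The caller writes Explainer(...) but gets back a Partition instance, which matches what the debugger shows for the demo.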

tokenizer

Here it is a PreTrainedTokenizerFast.

PreTrainedTokenizerFast(name_or_path='D:\model_repository\transformers\distilbert-base-uncased-finetuned-sst-2-english', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

masker

  • shap.maskers._text.Text#__init__(self,tokenizer)
    The constructor; it receives the tokenizer of the transformer pipeline.
  • shap.maskers._text.Text#__call__(self, mask:List[boolean], s:str)
    def __call__(self, mask:List[boolean], s:str)
        self._update_s_cache(s)
    
        for i, v in enumerate(mask):
            # mask ignores separator tokens and keeps them unmasked
            if v or sep_token == self._segments_s[i]:
                out_parts.append(self._segments_s[i])
                is_previous_appended_token_mask_token = False
            else:
                if not self.collapse_mask_token or (
                        self.collapse_mask_token and not is_previous_appended_token_mask_token):
                    out_parts.append(" " + self.mask_token)
                    is_previous_appended_token_mask_token = True
        # the leading and trailing '' segments add nothing when joined, so the length is back to M
        out = "".join(out_parts)
        result = np.array([out])
        return (result,)
    
    def _update_s_cache(self, s):
        tokens, token_ids = self.token_segments(s)
        self._tokenized_s = np.array(token_ids)
        # ['' 'What ' 'a ' 'great ' 'movie' '! ' 'if ' 'you ' 'have ' 'no ' 'taste', '.' '']
        self._segments_s = np.array(tokens)
    
    def token_segments(self, s):
        # s: 'What a great movie! if you have no taste.'
        token_data = self.tokenizer(s, return_offsets_mapping=True)
        offsets = token_data["offset_mapping"]
        offsets = [(0, 0) if o is None else o for o in offsets]
        parts = [s[offsets[i][0]:max(offsets[i][1], offsets[i + 1][0])] for i in range(len(offsets) - 1)]
        parts.append(s[offsets[len(offsets) - 1][0]:offsets[len(offsets) - 1][1]])
        # ['', 'What ', 'a ', 'great ', 'movie', '! ', 'if ', 'you ', 'have ', 'no ', 'taste', '.', '']
        return parts, token_data["input_ids"]
    
    
    

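s is the raw input string. It is passed to tokenizer.__call__(s), after which the number of segments matches len(mask) so the two can be used together. The return value is still text, e.g. ['What a great movie! [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]'].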
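
The masking step can be imitated without a tokenizer by starting from the segment list shown in the comments above. This is a simplified stand-in, not shap's Text masker: `apply_text_mask` is a hypothetical name, empty segments (the special-token positions) are passed through unchanged, and the mask token is inserted without a leading space, which is an assumption chosen to match the strings printed in the log at the end of this post.

```python
def apply_text_mask(segments, mask, mask_token="[MASK]"):
    """Keep segments where mask is True, replace the others with mask_token."""
    if len(segments) != len(mask):
        raise ValueError("segments and mask must have the same length")
    parts = []
    for segment, keep in zip(segments, mask):
        if keep or segment == "":
            parts.append(segment)      # kept token, or an empty special-token slot
        else:
            parts.append(mask_token)   # masked-out token
    return "".join(parts)


segments = ['', 'What ', 'a ', 'great ', 'movie', '! ',
            'if ', 'you ', 'have ', 'no ', 'taste', '.', '']
# Keep "What a great movie! " and mask the rest.
mask = [False, True, True, True, True, True] + [False] * 7
masked = apply_text_mask(segments, mask)
print(masked)
```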

  • shap.maskers._text.Text.clustering(self,s)
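    Build a hierarchical clustering of tokens that align with sentence structure.
    Internally it calls the partition_tree function below.
  • shap.maskers._text.partition_tree(decoded_tokens, special_tokens=None)
    Builds a hierarchical clustering tree. Since its purpose is partitioning, it can also be called a partition tree.
    The tree is stored as an ndarray with shape=(M - 1, 4), where M=len(decoded_tokens).
    The reason is that a binary tree with M leaf nodes has M-1 non-leaf nodes,
    so only the necessary non-leaf-node information is stored.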

MaskedModel

shap.utils._masked_model.MaskedModel.
This is a utility class that combines a model, a masker object, and a current input.

  • MaskedModel#__init__(self, model, masker, link, linearize_link, *args)
    model is a shap.models._transformers_pipeline.TransformersPipeline, and *args is the text.
  • MaskedModel#__call__(self, masks:ndarray[bool],)
    Calls the function _full_masking_call() below.
    • MaskedModel#_full_masking_call(self, masks, zero_index=None, batch_size=None)
      masks is the batch_masks:
    def _full_masking_call(self, masks, batch_size=None):
    	
        for i, mask in enumerate(masks):
            # len(mask)=M+2
            # self.args is a str with M tokens
            # masker.__call__() calls tokenizer.__call__() internally; the net effect is that len(masked_inputs) becomes M again
            # so as long as mask1[1:-1] == mask2[1:-1], the resulting masked_inputs are identical.
            masked_inputs = self.masker.__call__(mask, *self.args)
            for i in range(len(masked_inputs)):
                all_masked_inputs[i].append(masked_inputs[i])
        joined_masked_inputs = self._stack_inputs(*all_masked_inputs)
        # model is a Pipeline object and runs the tokenizer again internally; input_ids has length M+2
        outputs = self.model(*joined_masked_inputs)
        
        _build_fixed_output(averaged_outs, last_outs, outputs, batch_positions, varying_rows, num_varying_rows, self.link)
        return averaged_outs
    
  • shap.utils._masked_model.make_masks(cluster_matrix)
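    From the hierarchical clustering tree, returns the corresponding list of masks as a scipy.sparse.csr_matrix. Converted to a dense matrix it has shape = [M+M-1,M], representing the leaf nodes of the tree (the original tokens) and the non-leaf nodes (the dynamically generated phrases).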

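IV. Tracing the demo with breakpoints

explain_row(): the top-level method

shap explains the model's prediction for a single sample, so the entry point is shap.explainers._partition.Partition.explain_row().
Internally it constructs the MaskedModel object described above.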

def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent,
                fixed_context=None):
    # row_args is ('What a great movie! if you have no taste.',), with 11 effective tokens.

    # A MaskedModel object is short-lived; it is tied to one concrete sample.
    # Its constructor calls tokenizer.__call__(row_args), which adds [CLS] and [SEP], so the length goes from 11 to 13.
    fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args)

    # M=13
    M = len(fm)
    # mask is all False: every token is masked out.
    m00 = np.zeros(M, dtype=bool)
    # MaskedModel.__call__()
    self._curr_base_value = fm(m00.reshape(1, -1))[0]
    # mask is all True: every token is kept.
    f11 = fm(~m00.reshape(1, -1))[0]

    if callable(self.masker.clustering):
        self._clustering = self.masker.clustering(*row_args)
        self._mask_matrix = make_masks(self._clustering)

    out_shape = (2 * self._clustering.shape[0] + 1,)

    if max_evals == "auto":
        max_evals = 500

    self.values = np.zeros(out_shape)
    self.dvalues = np.zeros(out_shape)

    self.owen(fm, self._curr_base_value, f11, max_evals // 2 - 2, outputs, fixed_context, batch_size, silent)

    self.values[:] = self.dvalues

    def lower_credit(i, value=0):
        if i < M:
            self.values[i] += value
            return
        li = int(self._clustering[i - M, 0])
        ri = int(self._clustering[i - M, 1])
        group_size = int(self._clustering[i - M, 3])
        lsize = int(self._clustering[li - M, 3]) if li >= M else 1
        rsize = int(self._clustering[ri - M, 3]) if ri >= M else 1
        assert lsize + rsize == group_size
        self.values[i] += value
        lower_credit(li, self.values[i] * lsize / group_size)
        lower_credit(ri, self.values[i] * rsize / group_size)
    # recursive call; this is the key place where self.values gets updated
    lower_credit(len(self.dvalues) - 1)

    return {
        "values": self.values[:M].copy(),
        "expected_values": self._curr_base_value if outputs is None else self._curr_base_value[outputs],
        "mask_shapes": [s + out_shape[1:] for s in fm.mask_shapes],
        "main_effects": None,
        "hierarchical_values": self.dvalues.copy(),
        "clustering": self._clustering,
        "output_indices": outputs,
        "output_names": getattr(self.model, "output_names", None)
    }

1.1 Computing base_value

MaskedModel.__call__(masks=[False,…,False]) internally calls the wrapped Hugging Face transformers model and obtains the output for the fully masked input, self._curr_base_value=[ 0.69096287 -0.690963 ].

1.2 Computing the unmasked output f11

MaskedModel.__call__(masks=[True…]) obtains the output for the unmodified input, f11=[-8.77722551 8.77673741].
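
Two quick checks tie these numbers together. They rely only on values printed elsewhere in this post (the POSITIVE probability for the fully masked input in the log, and the POSITIVE shap values from the demo output). The reading that the outputs are in log-odds units is an inference from those numbers, not something taken from the shap source.

```python
import math


def logit(p):
    """Log-odds of a probability."""
    return math.log(p / (1.0 - p))


# POSITIVE probability logged for the fully masked input.
p_masked = 0.33381909132003784
base_value = -0.6909630029188477
print(logit(p_masked), base_value)

# POSITIVE shap values for the 13 token positions.
values = ([0.0, 1.2201888, 1.2201888, 3.8936163, 3.8936163, 0.24495573]
          + [-0.16747759] * 6 + [0.0])
f11_positive = 8.77673741

# Additivity: base value plus all attributions should recover f11.
reconstructed = base_value + sum(values)
print(reconstructed, f11_positive)
```

Both pairs agree to several decimal places, which is consistent with base_value being the log-odds of the fully masked input and with the attributions summing to f11 - base_value.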

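1.3 Building the hierarchical clustering partition tree

partition_tree=shap.maskers._text.Text.clustering(s='What a great movie! if you have no taste.')
Build a hierarchical clustering of tokens that align with sentence structure.

Internally, the input is first preprocessed into the form BERT accepts:
decoded_x=['[CLS]', 'what', 'a', 'great', 'movie', '!', 'if', 'you', 'have', 'no', 'taste', '.', '[SEP]'], shape=(13,)
It then calls pt=shap.maskers._text.partition_tree(decoded_tokens, special_tokens=None) and returns the result.

The logic of partition_tree is expanded below.
It relies on shap.maskers._text.merge_score(group1, group2, special_tokens), which scores the merge of two groups; the higher the score, the more the two groups tend to be merged into a new clause.
At the start, each group is just the token at its position. Iterating merge_score keeps producing longer clauses, which is exactly hierarchical clustering.
The initial number of each token is listed below: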

0 [CLS]
1 what
2 a
3 great
4 movie
5 !
6 if
7 you
8 have
9 no
10 taste
11 .
12 [SEP]

During iteration, new clauses are numbered consecutively starting from M. The generated clauses and the corresponding partition_tree are:

13 [what, a]
14 [great, movie]
15 [if, you]
16 [have, no]
17 [have, no, taste]
18 [what, a, great, movie]
19 [if, you, have, no, taste]
20 [what, a, great, movie, !]
21 [if, you, have, no, taste, .]
22 [what, a, great, movie, !, if, you, have, no, taste, .]
23 [[CLS], what, a, great, movie, !, if, you, have, no, taste, .]
24 [[CLS], what, a, great, movie, !, if, you, have, no, taste, ., [SEP]]
------------------------------------------------
[[ 1.          2.          0.15384615  2.        ]
 [ 3.          4.          0.15384615  2.        ]
 [ 6.          7.          0.15384615  2.        ]
 [ 8.          9.          0.15384615  2.        ]
 [16.         10.          0.23076923  3.        ]
 [13.         14.          0.30769231  4.        ]
 [15.         17.          0.38461538  5.        ]
 [18.          5.          0.38461538  5.        ]
 [19.         11.          0.46153846  6.        ]
 [20.         21.          0.84615385 11.        ]
 [ 0.         22.          0.92307692 12.        ]
 [23.         12.          1.         13.        ]]
 --------------------------------------------------------

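1.4 Building the mask matrix

Pass the partition tree just built into this method to obtain the mask matrix.
self._mask_matrix = shap.utils._masked_model.make_masks(self._clustering), shape = [2*M-1,M], i.e. [25,13].
It is a scipy.sparse.csr_matrix object. Converted to a dense matrix and printed, it looks as follows; it simply represents the original tokens and the dynamically generated phrases above.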

[[ True False False False False False False False False False False False False]
 [False  True False False False False False False False False False False False]
 [False False  True False False False False False False False False False False]
 [False False False  True False False False False False False False False False]
 [False False False False  True False False False False False False False False]
 [False False False False False  True False False False False False False False]
 [False False False False False False  True False False False False False False]
 [False False False False False False False  True False False False False False]
 [False False False False False False False False  True False False False False]
 [False False False False False False False False False  True False False False]
 [False False False False False False False False False False  True False False]
 [False False False False False False False False False False False  True False]
 [False False False False False False False False False False False False  True]
 [False  True  True False False False False False False False False False False]
 [False False False  True  True False False False False False False False False]
 [False False False False False False  True  True False False False False False]
 [False False False False False False False False  True  True False False False]
 [False False False False False False False False  True  True  True False False]
 [False  True  True  True  True False False False False False False False False]
 [False False False False False False  True  True  True  True  True False False]
 [False  True  True  True  True  True False False False False False False False]
 [False False False False False False  True  True  True  True  True  True False]
 [False  True  True  True  True  True  True  True  True  True  True  True False]
 [ True  True  True  True  True  True  True  True  True  True  True  True False]
 [ True  True  True  True  True  True  True  True  True  True  True  True  True]]
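
The matrix above can be rebuilt from the partition tree with a few lines of NumPy. This is a dense re-implementation sketch, not shap's make_masks (which returns a sparse matrix); `make_masks_dense` is a hypothetical name, and the child pairs and group sizes are copied from columns 0, 1 and 3 of the partition_tree printed in section 1.3.

```python
import numpy as np


def make_masks_dense(child_pairs):
    """Row i < M marks leaf i; row M + k marks every leaf under merge k."""
    M = len(child_pairs) + 1
    masks = np.zeros((2 * M - 1, M), dtype=bool)
    masks[np.arange(M), np.arange(M)] = True       # leaves: identity rows
    for k, (left, right) in enumerate(child_pairs):
        # Children always have smaller ids than the merged node.
        masks[M + k] = masks[left] | masks[right]
    return masks


child_pairs = [(1, 2), (3, 4), (6, 7), (8, 9), (16, 10), (13, 14),
               (15, 17), (18, 5), (19, 11), (20, 21), (0, 22), (23, 12)]
group_sizes = [2, 2, 2, 2, 3, 4, 5, 5, 6, 11, 12, 13]

mask_matrix = make_masks_dense(child_pairs)
print(mask_matrix.shape)
print(np.flatnonzero(mask_matrix[17]))  # leaves under node 17: have, no, taste
```

The row sums of the lower block equal the group-size column of the partition tree, which is a convenient consistency check.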

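1.5 Calling the owen method

shap.explainers._partition.Partition.owen(self, fm, f00, f11, max_evals, output_indexes, fixed_context, batch_size, silent)
The method contains a while eval_count < max_evals loop. In the loop body it builds the batch_masks for the current batch from the mask_matrix above and passes them to the model.
The method body has many logic branches; a condensed reading follows: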

def owen(self, fixed_context=None):
    # M=13
    # ind starts at 2*(M-1)=2*12=24
    q = queue.PriorityQueue()
    # (0, 0, (m00, f00, f11, ind, 1.0))
    # the first 0 is the priority used for ordering; the second 0 is a redundant tiebreaker
    q.put((0, 0, (m00, f00, f11, ind, 1.0)))

    while not q.empty():
        # cost control
        if eval_count >= max_evals:
            while not q.empty():
                m00, f00, f11, ind, weight = q.get()[2]
                # in most cases eval_count does exceed max_evals, so this block runs many times
                # ind here is usually >= M, i.e. a non-leaf node
                self.dvalues[ind] += (f11 - f00) * weight
            break

        # each entry of batch_args is a 9-tuple
        batch_args = []
        batch_masks = []
        while not q.empty() and len(batch_masks) < batch_size and eval_count < max_evals:
            # get() pops the element; [2] is the five-tuple args
            m00, f00, f11, ind, weight = q.get()[2]
            # get the left and right children of this cluster
            # on the first iteration, ind - M indexes the root of the hierarchy
            lind = int(self._clustering[ind - M, 0]) if ind >= M else -1
            rind = int(self._clustering[ind - M, 1]) if ind >= M else -1

            # get the distance of this cluster's children
            if ind < M:
                distance = -1
            else:
                # this can also be negative
                distance = self._clustering[ind - M, 2]
            if distance < 0:
                self.dvalues[ind] += (f11 - f00) * weight
                continue

            # effectively m10 = m00 + self._mask_matrix[lind]; written this way to create fewer objects and ease GC pressure?
            m10 = m00.copy()
            m10[:] += self._mask_matrix[lind, :]
            m01 = m00.copy()
            m01[:] += self._mask_matrix[rind, :]
            # len(batch_masks) always equals 2*len(batch_args)
            batch_args.append((m00, m10, m01, f00, f11, ind, lind, rind, weight))
            batch_masks.append(m10)
            batch_masks.append(m01)
        # back in the outer loop
        batch_masks = np.array(batch_masks)
        # model inference!
        fout = fm(batch_masks)
        # count the evaluated samples
        eval_count += len(batch_masks)

        # use the results of the batch to add new nodes
        for i in range(len(batch_args)):
            m00, m10, m01, f00, f11, ind, lind, rind, weight = batch_args[i]
            # get the evaluated model output on the two new masked inputs
            f10 = fout[2 * i]
            f01 = fout[2 * i + 1]

            new_weight = weight
            if fixed_context is None:
                new_weight /= 2

            if fixed_context is None or fixed_context == 0:
                # recurse on the left node with zero context
                args = (m00, f00, f10, lind, new_weight)
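                # the priority queue sorts ascending by default; negate so the largest change is taken first
                # the priority reflects both the score magnitude and the node's level: the deeper the tree rooted at the node, the larger its weight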
                q.put((-np.max(np.abs(f10 - f00)) * new_weight, np.random.randn(), args))

                # recurse on the right node with zero context
                args = (m00, f00, f01, rind, new_weight)
                q.put((-np.max(np.abs(f01 - f00)) * new_weight, np.random.randn(), args))

            if fixed_context is None or fixed_context == 1:
                # recurse on the left node with one context
                args = (m01, f01, f11, lind, new_weight)
                q.put((-np.max(np.abs(f11 - f01)) * new_weight, np.random.randn(), args))

                # recurse on the right node with one context
                args = (m10, f10, f11, rind, new_weight)
                q.put((-np.max(np.abs(f11 - f10)) * new_weight, np.random.randn(), args))
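
The shape of the queue entries, (negated priority, random tiebreak, args), can be tried in isolation. A small sketch with made-up deltas: the negation makes the largest change come out first, and the random second element means the comparison never reaches args, which holds NumPy arrays that cannot be ordered. The labels and numbers are illustrative only.

```python
import queue

import numpy as np

q = queue.PriorityQueue()
rng = np.random.default_rng(0)

# (label, |f11 - f00|, weight) for three pretend coalitions
for label, delta, weight in [("small", 0.1, 0.5), ("large", 3.0, 0.5), ("mid", 1.0, 0.5)]:
    args = (np.zeros(3, dtype=bool), label)   # stands in for the five-tuple
    q.put((-abs(delta) * weight, rng.standard_normal(), args))

# Entries come out in order of decreasing |delta| * weight.
order = [q.get()[2][1] for _ in range(3)]
print(order)
```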

1.6 lower_credit

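A function defined inside explain_row().
owen has cost control to bound the running time, so some leaf nodes are never fully evaluated.
This function recursively spreads a cluster's contribution down to its leaf nodes.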

def lower_credit(i, value=0):
        if i < M:
            self.values[i] += value
            return
        li = int(self._clustering[i - M, 0])
        ri = int(self._clustering[i - M, 1])
        group_size = int(self._clustering[i - M, 3])
        lsize = int(self._clustering[li - M, 3]) if li >= M else 1
        rsize = int(self._clustering[ri - M, 3]) if ri >= M else 1
        assert lsize + rsize == group_size
        self.values[i] += value
        lower_credit(li, self.values[i] * lsize / group_size)
        lower_credit(ri, self.values[i] * rsize / group_size)
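
The same recursion can be run standalone on a tiny tree to see where the credit goes. In this sketch `lower_credit_all` is a hypothetical wrapper, the 3-leaf clustering and the dvalues are invented for the example, and the row layout (left child, right child, distance, group size) follows the partition_tree printed in section 1.3.

```python
import numpy as np


def lower_credit_all(clustering, dvalues):
    """Push credit held by internal nodes down to leaves, split by group size."""
    M = clustering.shape[0] + 1
    values = np.array(dvalues, dtype=float)

    def lower_credit(i, value=0.0):
        if i < M:
            values[i] += value
            return
        li = int(clustering[i - M, 0])
        ri = int(clustering[i - M, 1])
        group_size = int(clustering[i - M, 3])
        lsize = int(clustering[li - M, 3]) if li >= M else 1
        rsize = int(clustering[ri - M, 3]) if ri >= M else 1
        values[i] += value
        lower_credit(li, values[i] * lsize / group_size)
        lower_credit(ri, values[i] * rsize / group_size)

    lower_credit(len(values) - 1)   # start from the root
    return values[:M]


# Node 3 = (leaf 0, leaf 1), node 4 = (node 3, leaf 2)
clustering = np.array([[0, 1, 0.5, 2],
                       [3, 2, 1.0, 3]])
# Credit recorded at leaf 0, at node 3 and at the root (node 4).
dvalues = [0.5, 0.0, 0.0, 1.0, 3.0]

leaf_values = lower_credit_all(clustering, dvalues)
print(leaf_values)
```

For this input the leaves end up with 2.0, 1.5 and 1.0, and their sum equals the total credit in dvalues (4.5), so nothing is lost when the budget cuts the expansion short.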

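Discussion

Cache acceleration

owen() builds batch_masks and passes them to MaskedModel(batch_masks).
The masks held in batch_masks have a particular property: there are several pairs (i,j) with masks[i][1:-1] == masks[j][1:-1], so caching can be used for acceleration.

The reasoning is as follows:
When MaskedModel.__call__() builds the masked_inputs for the model pipeline, it drops the leading [CLS] and the trailing [SEP]. So although mask[i] and mask[j] may differ at the first and last positions, their results are necessarily identical.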

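The observed behavior: within one batch_mask, the same effective part of the mask is repeated four times. With caching the speed-up ratio is therefore 4:1, i.e. model calls drop to a quarter.

Q: Are there repeats across different batch_masks?
No; this is determined by the q.put() calls in owen.
So the scope of the cache can be limited to a single batch.
Batch sizes follow a pattern, (2,4,8,32,64). Why is there no 16? todo.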
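
A batch-scoped cache along these lines might look as follows. It is a sketch under the assumption stated above, namely that the model output depends only on mask[1:-1]; `cached_batch_call` and `fake_model` are hypothetical names, and the four masks are copied from one of the batches in the printed log.

```python
import numpy as np


def cached_batch_call(model_fn, masks):
    """Evaluate model_fn once per distinct mask[1:-1] within one batch."""
    cache = {}
    outputs = []
    for mask in masks:
        key = mask[1:-1].tobytes()        # ignore the [CLS] and [SEP] positions
        if key not in cache:
            cache[key] = model_fn(mask)
        outputs.append(cache[key])
    return np.array(outputs), len(cache)


calls = []


def fake_model(mask):
    # Stand-in for the pipeline: records the call and returns a dummy score.
    calls.append(mask.copy())
    return float(mask[1:-1].sum())


masks = np.array([
    [True] + [False] * 12,
    [False] + [True] * 11 + [False],
    [True] + [False] * 11 + [True],
    [False] + [True] * 12,
])
outputs, n_unique = cached_batch_call(fake_model, masks)
print(outputs, n_unique, len(calls))
```

Here four masks collapse to two model calls; the ratio depends on how many masks in the batch share the same inner part.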

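Printed output

The details of each mask matrix during the iteration and the corresponding model outputs are listed below, including the calls from sections 1.1 and 1.2.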

enter owen invoking -------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00015418155817314982}, {'label': 'POSITIVE', 'score': 0.9998457431793213}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.6661808490753174}, {'label': 'POSITIVE', 'score': 0.33381909132003784}]
------------------------------------------------------------
for i, mask in enumerate(masks) done,masks=  [[ True  True  True  True  True  True  True  True  True  True  True  True False]
 [False False False False False False False False False False False False  True]]
==================================================
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.6661808490753174}, {'label': 'POSITIVE', 'score': 0.33381909132003784}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00015418155817314982}, {'label': 'POSITIVE', 'score': 0.9998457431793213}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.6661808490753174}, {'label': 'POSITIVE', 'score': 0.33381909132003784}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00015418155817314982}, {'label': 'POSITIVE', 'score': 0.9998457431793213}]
------------------------------------------------------------
for i, mask in enumerate(masks) done,masks=  [[ True False False False False False False False False False False False False]
 [False  True  True  True  True  True  True  True  True  True  True  True False]
 [ True False False False False False False False False False False False  True]
 [False  True  True  True  True  True  True  True  True  True  True  True  True]]
==================================================
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie! [MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00021430315973702818}, {'label': 'POSITIVE', 'score': 0.9997857213020325}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK][MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9538852572441101}, {'label': 'POSITIVE', 'score': 0.046114809811115265}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie! [MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00021430315973702818}, {'label': 'POSITIVE', 'score': 0.9997857213020325}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK][MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9538852572441101}, {'label': 'POSITIVE', 'score': 0.046114809811115265}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie! [MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00021430315973702818}, {'label': 'POSITIVE', 'score': 0.9997857213020325}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK][MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9538852572441101}, {'label': 'POSITIVE', 'score': 0.046114809811115265}]
------------------------------------------------------------
for i, mask in enumerate(masks) done,masks=  [[False  True  True  True  True  True False False False False False False False]
 [False False False False False False  True  True  True  True  True  True False]
 [False  True  True  True  True  True False False False False False False  True]
 [False False False False False False  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True False False False False False False False]
 [ True False False False False False  True  True  True  True  True  True False]]
==================================================
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie! [MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00021430315973702818}, {'label': 'POSITIVE', 'score': 0.9997857213020325}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK][MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9538852572441101}, {'label': 'POSITIVE', 'score': 0.046114809811115265}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie[MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00020252222020644695}, {'label': 'POSITIVE', 'score': 0.999797523021698}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK]! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9573299884796143}, {'label': 'POSITIVE', 'score': 0.04266996309161186}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie[MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00020252222020644695}, {'label': 'POSITIVE', 'score': 0.999797523021698}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK]! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9573299884796143}, {'label': 'POSITIVE', 'score': 0.04266996309161186}]
------------------------------------------------------------
for i, mask in enumerate(masks) done,masks=  [[ True  True  True  True  True  True False False False False False False  True]
 [ True False False False False False  True  True  True  True  True  True  True]
 [False  True  True  True  True False  True  True  True  True  True  True  True]
 [False False False False False  True  True  True  True  True  True  True  True]
 [ True  True  True  True  True False  True  True  True  True  True  True False]
 [ True False False False False  True  True  True  True  True  True  True False]]
==================================================
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie[MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00020252222020644695}, {'label': 'POSITIVE', 'score': 0.999797523021698}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK]! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9573299884796143}, {'label': 'POSITIVE', 'score': 0.04266996309161186}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie[MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00020252222020644695}, {'label': 'POSITIVE', 'score': 0.999797523021698}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK]! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9573299884796143}, {'label': 'POSITIVE', 'score': 0.04266996309161186}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie[MASK][MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.0006543286144733429}, {'label': 'POSITIVE', 'score': 0.9993457198143005}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK]! [MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.7347449660301208}, {'label': 'POSITIVE', 'score': 0.26525506377220154}]
------------------------------------------------------------
for i, mask in enumerate(masks) done,masks=  [[False  True  True  True  True False  True  True  True  True  True  True False]
 [False False False False False  True  True  True  True  True  True  True False]
 [ True  True  True  True  True False  True  True  True  True  True  True  True]
 [ True False False False False  True  True  True  True  True  True  True  True]
 [ True  True  True  True  True False False False False False False False  True]
 [ True False False False False  True False False False False False False  True]]
==================================================
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie[MASK][MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.0006543286144733429}, {'label': 'POSITIVE', 'score': 0.9993457198143005}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK]! [MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.7347449660301208}, {'label': 'POSITIVE', 'score': 0.26525506377220154}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie[MASK][MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.0006543286144733429}, {'label': 'POSITIVE', 'score': 0.9993457198143005}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK]! [MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.7347449660301208}, {'label': 'POSITIVE', 'score': 0.26525506377220154}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a great movie[MASK][MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.0006543286144733429}, {'label': 'POSITIVE', 'score': 0.9993457198143005}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK][MASK][MASK]! [MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.7347449660301208}, {'label': 'POSITIVE', 'score': 0.26525506377220154}]
------------------------------------------------------------
for i, mask in enumerate(masks) done,masks=  [[False  True  True  True  True False False False False False False False  True]
 [False False False False False  True False False False False False False  True]
 [ True  True  True  True  True False False False False False False False False]
 [ True False False False False  True False False False False False False False]
 [False  True  True  True  True False False False False False False False False]
 [False False False False False  True False False False False False False False]]
==================================================
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a [MASK][MASK]! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.4105953574180603}, {'label': 'POSITIVE', 'score': 0.5894045829772949}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK]great movie! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00015325885033234954}, {'label': 'POSITIVE', 'score': 0.9998466968536377}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a [MASK][MASK]! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.4105953574180603}, {'label': 'POSITIVE', 'score': 0.5894045829772949}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK]great movie! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00015325885033234954}, {'label': 'POSITIVE', 'score': 0.9998466968536377}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a [MASK][MASK]! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.4105953574180603}, {'label': 'POSITIVE', 'score': 0.5894045829772949}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK]great movie! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00015325885033234954}, {'label': 'POSITIVE', 'score': 0.9998466968536377}]
------------------------------------------------------------
for i, mask in enumerate(masks) done,masks=  [[ True  True  True False False  True  True  True  True  True  True  True  True]
 [ True False False  True  True  True  True  True  True  True  True  True  True]
 [False  True  True False False  True  True  True  True  True  True  True False]
 [False False False  True  True  True  True  True  True  True  True  True False]
 [False  True  True False False  True  True  True  True  True  True  True  True]
 [False False False  True  True  True  True  True  True  True  True  True  True]]
==================================================
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a [MASK][MASK]! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.4105953574180603}, {'label': 'POSITIVE', 'score': 0.5894045829772949}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK]great movie! if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00015325885033234954}, {'label': 'POSITIVE', 'score': 0.9998466968536377}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a [MASK][MASK][MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9434953927993774}, {'label': 'POSITIVE', 'score': 0.05650459602475166}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK]great movie[MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00016879562463145703}, {'label': 'POSITIVE', 'score': 0.9998311996459961}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a [MASK][MASK][MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9434953927993774}, {'label': 'POSITIVE', 'score': 0.05650459602475166}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK]great movie[MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00016879562463145703}, {'label': 'POSITIVE', 'score': 0.9998311996459961}]
------------------------------------------------------------
for i, mask in enumerate(masks) done,masks=  [[ True  True  True False False  True  True  True  True  True  True  True False]
 [ True False False  True  True  True  True  True  True  True  True  True False]
 [ True  True  True False False False  True  True  True  True  True  True False]
 [ True False False  True  True False  True  True  True  True  True  True False]
 [False  True  True False False False  True  True  True  True  True  True False]
 [False False False  True  True False  True  True  True  True  True  True False]]
==================================================
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a [MASK][MASK][MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9434953927993774}, {'label': 'POSITIVE', 'score': 0.05650459602475166}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK]great movie[MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00016879562463145703}, {'label': 'POSITIVE', 'score': 0.9998311996459961}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a [MASK][MASK][MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.9434953927993774}, {'label': 'POSITIVE', 'score': 0.05650459602475166}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK]great movie[MASK]if you have no taste.
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.00016879562463145703}, {'label': 'POSITIVE', 'score': 0.9998311996459961}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs What a [MASK][MASK]! [MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.0912800058722496}, {'label': 'POSITIVE', 'score': 0.9087200164794922}]
------------------------------------------------------------
yichu,transformers.pipelines.base.Pipeline.run_single,inputs [MASK][MASK]great movie! [MASK][MASK][MASK][MASK][MASK][MASK]
yichu,transformers.pipelines.base.Pipeline.run_single,outputs [{'label': 'NEGATIVE', 'score': 0.0002873070479836315}, {'label': 'POSITIVE', 'score': 0.9997126460075378}]
------------------------------------------------------------
for i, mask in enumerate(masks) done,masks=  [[ True  True  True False False False  True  True  True  True  True  True  True]
 [ True False False  True  True False  True  True  True  True  True  True  True]
 [False  True  True False False False  True  True  True  True  True  True  True]
 [False False False  True  True False  True  True  True  True  True  True  True]
 [ True  True  True False False  True False False False False False False False]
 [ True False False  True  True  True False False False False False False False]]
==================================================
ends owen invoking -------------------------------
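Reading the trace above: each row of a printed `masks` array corresponds to one string sent to the pipeline, where `True` keeps a token and `False` replaces it with `[MASK]`. The following is a minimal sketch of that mapping, not shap's own masker code. `render_mask` is a hypothetical helper name, the token strings are copied from the `.data` array printed with the final result, and the rule that the two empty special tokens never produce text is inferred from the logged inputs rather than from the library source.

```python
# Token strings copied from the .data dump; the two empty strings stand for [CLS] and [SEP].
TOKENS = ['', 'What ', 'a ', 'great ', 'movie', '! ', 'if ', 'you ', 'have ', 'no ', 'taste', '.', '']


def render_mask(mask, tokens=TOKENS, mask_token="[MASK]"):
    """Rebuild the masked input string for one mask row (True = keep the token).

    Hypothetical helper for reading the log; it only mirrors what the trace shows.
    """
    parts = []
    for keep, tok in zip(mask, tokens):
        if tok == "":
            # Special tokens contribute no text, whether masked or not (inferred from the log).
            continue
        # A masked token loses its trailing space too, e.g. '! ' becomes '[MASK]'.
        parts.append(tok if keep else mask_token)
    return "".join(parts)


if __name__ == "__main__":
    row = [True, True, True, True, True, False] + [True] * 7
    print(render_mask(row))  # What a great movie[MASK]if you have no taste.
```

Rows that differ only at position 0 or 12 render to the same string, which is why the trace shows the same input being evaluated several times within one batch.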
shap_values
.values =
array([[[ 0.        ,  0.        ],
        [-1.22018141,  1.22019897],
        [-1.22018141,  1.22019897],
        [-3.89365494,  3.8936278 ],
        [-3.89365494,  3.8936278 ],
        [-0.24507859,  0.24491016],
        [ 0.16742733, -0.16747737],
        [ 0.16742733, -0.16747737],
        [ 0.16742733, -0.16747737],
        [ 0.16742733, -0.16747737],
        [ 0.16742733, -0.16747737],
        [ 0.16742733, -0.16747737],
        [ 0.        ,  0.        ]]])

.base_values =
array([[ 0.6909618 , -0.69096206]])

.data =
array([['', 'What ', 'a ', 'great ', 'movie', '! ', 'if ', 'you ', 'have ', 'no ', 'taste', '.',
        '']], dtype='<U6')
==========
positive_shap_values
.values =
array([ 0.        ,  1.22019897,  1.22019897,  3.8936278 ,  3.8936278 ,  0.24491016, -0.16747737,
       -0.16747737, -0.16747737, -0.16747737, -0.16747737, -0.16747737,  0.        ])

.base_values =
-0.6909620648280776

.data =
array(['', 'What ', 'a ', 'great ', 'movie', '! ', 'if ', 'you ', 'have ', 'no ', 'taste', '.', ''],
      dtype='<U6')
==========