2021SC@SDUSC
在前面的train.py的evaluate_model模块中,引用了Reasoning_AMR_CN_DUAL
模块:
eval_model = Reasoning_AMR_CN_DUAL
下面分析一下models/reasoningCT.PY这个程序中的Reasoning_AMR_CN_DUAL
模块。Reasoning_AMR_CN_DUAL是reasoningCT.PY程序中定义的一个类。
它是基于ACP图的一个描述。
其中定义了一个函数encoder_attn,注意力编码模块。用于计算关联关系的注意力值。
def encoder_attn(self, inp):
with torch.no_grad():
concept_repr = self.embed_scale * self.concept_encoder(inp['concept'],
inp['concept_char'] + self.concept_depth(
inp['concept_depth']))
concept_repr = self.concept_embed_layer_norm(concept_repr)
concept_mask = torch.eq(inp['concept'], self.vocabs['concept'].padding_idx)
relation = self.relation_encoder(inp['relation_bank'], inp['relation_length'])
relation[0, :] = 0.
relation = relation[inp['relation']]
sum_relation = relation.sum(dim=3)
num_valid_paths = inp['relation'].ne(0).sum(dim=3).clamp_(min=1)
divisor = (num_valid_paths).unsqueeze(-1).type_as(sum_relation)
relation = sum_relation / divisor
attn, attn_weights = self.graph_encoder.get_attn_weights(concept_repr, relation, self_padding_mask=concept_mask)
return attn
这里调用了graph_transformer.py程序的GraphTransformer类的get_attn_weights方法。下面分析一下。
它通过nn.Module类来定义模型,并计算注意力权值。
总结一下:到这篇,ACP_CSQA这个项目的学习就告一个段落了,这个学习过程自己接触了很多没有学过的领域,也清楚地知道还有很多东西要学,要深入了解和掌握一个领域的技术不是一件容易的事情,希望自己在今后加油努力。