In the nn.Transformer source code, the self-attention call inside TransformerEncoderLayer originally passes need_weights=False, so MultiheadAttention does not return its attention weights.
Change need_weights to True there so that MultiheadAttention outputs the attention weights.
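If you would rather not edit the installed source file, a minimal sketch of the same idea is to wrap the forward of each MultiheadAttention so that need_weights is always forced to True. Note that force_need_weights below is a hypothetical helper, not part of PyTorch, and average_attn_weights only exists in newer PyTorch releases:

import functools

def force_need_weights(mha):
    # Replace the module's forward with a wrapper that always asks for the weights.
    original_forward = mha.forward
    @functools.wraps(original_forward)
    def wrapped(*args, **kwargs):
        kwargs["need_weights"] = True
        # Optionally also set kwargs["average_attn_weights"] = False to keep
        # per-head weights (argument available in newer PyTorch versions only).
        return original_forward(*args, **kwargs)
    mha.forward = wrapped

# usage (after building the model shown below):
# for layer in model.layers:
#     force_need_weights(layer.self_attn)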
Then use a forward hook to extract the intermediate layer's output; sample code:
import torch
import torch.nn as nn
num_heads = 4
input_dim = 16
model = nn.TransformerEncoder(nn.TransformerEncoderLayer(input_dim, num_heads), 6)  # 6 stacked encoder layers
print(model)
query = torch.randn(10, 8, input_dim)  # (seq_len, batch, d_model); batch_first=False by default
features_in_hook = []
features_out_hook = []
def hook(module, fea_in, fea_out):
    features_in_hook.append(fea_in)   # drop this line if you do not need the inputs
    features_out_hook.append(fea_out)
    return None
# for (name, module) in model.named_modules():   # uncomment to see every submodule name
#     print(name)
# layer_name = 'layers.5.self_attn'
# for (name, module) in model.named_modules():
#     if name == layer_name:
#         module.register_forward_hook(hook=hook)
model.layers[-1].self_attn.register_forward_hook(hook)  # this one line replaces the four commented lines above; use the named lookup when hooking many layers
c = model(query)  # forward pass; the hook fires on the last layer's self-attention
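To read what the hook captured, here is a short sketch. fea_out from MultiheadAttention is the tuple (attn_output, attn_weights); attn_weights is None unless need_weights has been changed to True as described above, and the shapes below assume the head-averaged weights of this example:

attn_output, attn_weights = features_out_hook[0]
print(attn_output.shape)   # torch.Size([10, 8, 16])  -> (seq_len, batch, d_model)
print(attn_weights.shape)  # torch.Size([8, 10, 10])  -> (batch, seq_len, seq_len) with head-averaged weights;
                           # (8, 4, 10, 10) if average_attn_weights=False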