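expand_as(attention_weights) is a tensor method that expands the tensor it is called on (here, mask.unsqueeze(1)) to the same shape as the attention_weights tensor. Only dimensions of size 1 can be expanded, which is why the mask is first given a singleton head dimension with unsqueeze(1). If attention_weights has shape (batch_size, num_heads, seq_length, seq_length), then after expand_as, mask_expanded also has shape (batch_size, num_heads, seq_length, seq_length), so every attention head sees the same mask pattern.
- For example: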
mask_expanded = mask.unsqueeze(1).expand_as(attention_weights).bool()
- mask.unsqueeze(1) changes the mask shape from (32, 43, 43) to (32, 1, 43, 43).
- Suppose attention_weights has shape (32, 2, 43, 43).
- After expand_as(attention_weights), the mask has shape (32, 2, 43, 43).
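
The shape changes listed above can be checked with a small sketch. The inputs here are assumptions made for illustration: the mask is a hypothetical lower-triangular (causal) mask, the attention scores are random, and the final masked_fill step is only one common way such a mask might be applied, not necessarily what the original code does.

```python
import torch

batch_size, num_heads, seq_length = 32, 2, 43

# Hypothetical inputs (not from the original post):
# a causal mask repeated over the batch, and random attention scores.
mask = torch.tril(torch.ones(seq_length, seq_length)).repeat(batch_size, 1, 1)
attention_weights = torch.randn(batch_size, num_heads, seq_length, seq_length)

# Insert a singleton head dimension: (32, 43, 43) -> (32, 1, 43, 43)
mask_unsqueezed = mask.unsqueeze(1)

# Expand the singleton dimension to num_heads: (32, 1, 43, 43) -> (32, 2, 43, 43)
mask_expanded = mask_unsqueezed.expand_as(attention_weights).bool()

print(mask.shape, mask_unsqueezed.shape, mask_expanded.shape)

# Every head receives the same mask pattern.
same_across_heads = torch.equal(mask_expanded[:, 0], mask_expanded[:, 1])

# Assumed usage: hide masked-out positions before a softmax.
masked_weights = attention_weights.masked_fill(~mask_expanded, float("-inf"))
```

expand_as returns a view and does not copy the underlying data, so the mask is not physically duplicated per head until a later operation (such as .bool() here) materializes a new tensor.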
The expand_as operation
First published on 2024-07-27 20:25:25