def get_mean_emb(hidden_states, mask, eps=1e-9):
    """Mean-pool token embeddings over the positions selected by ``mask``.

    Each row of the result is sum(hidden_states * mask) / sum(mask) along
    dim 1, i.e. the (mask-weighted) average of the included token vectors.

    Args:
        hidden_states: Token embeddings; presumably shape
            ``(batch, seq_len, hidden)`` such as a BERT ``last_hidden_state``
            (TODO confirm with callers).
        mask: Padding/attention mask with one fewer trailing dim than
            ``hidden_states`` (presumably ``(batch, seq_len)``); nonzero entries
            mark tokens to include. Integer, bool or float masks all work.
        eps: Lower bound applied to the per-row mask total before dividing.
            Rows whose mask is all zeros then yield a zero vector instead of
            NaN. Rows with at least one included token are unaffected.

    Returns:
        Floating-point tensor with dim 1 reduced away (presumably
        ``(batch, hidden)``).
    """
    # Broadcast the mask over the hidden dimension so padded tokens contribute 0.
    weights = mask.unsqueeze(-1).float()
    summed = torch.sum(hidden_states * weights, dim=1)
    counts = mask.sum(dim=1, keepdim=True).float()
    # An all-padding row would otherwise give 0 / 0 = NaN; clamping keeps it at 0.
    return summed / counts.clamp(min=eps)
BERT embedding 取平均（masked mean pooling）
最新推荐文章于 2024-07-23 16:23:28 发布
![](https://img-home.csdnimg.cn/images/20240711042549.png)
def get_mean_emb(hidden_states, mask, eps=1e-9):
    """Mean-pool token embeddings over the positions selected by ``mask``.

    Each row of the result is sum(hidden_states * mask) / sum(mask) along
    dim 1, i.e. the (mask-weighted) average of the included token vectors.

    Args:
        hidden_states: Token embeddings; presumably shape
            ``(batch, seq_len, hidden)`` such as a BERT ``last_hidden_state``
            (TODO confirm with callers).
        mask: Padding/attention mask with one fewer trailing dim than
            ``hidden_states`` (presumably ``(batch, seq_len)``); nonzero entries
            mark tokens to include. Integer, bool or float masks all work.
        eps: Lower bound applied to the per-row mask total before dividing.
            Rows whose mask is all zeros then yield a zero vector instead of
            NaN. Rows with at least one included token are unaffected.

    Returns:
        Floating-point tensor with dim 1 reduced away (presumably
        ``(batch, hidden)``).
    """
    # Broadcast the mask over the hidden dimension so padded tokens contribute 0.
    weights = mask.unsqueeze(-1).float()
    summed = torch.sum(hidden_states * weights, dim=1)
    counts = mask.sum(dim=1, keepdim=True).float()
    # An all-padding row would otherwise give 0 / 0 = NaN; clamping keeps it at 0.
    return summed / counts.clamp(min=eps)