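import torch

def Lm_div(multihead):
    N = len(multihead)  # number of tensors in the list
    similarity = 0
    for i in range(N):
        for j in range(i + 1, N):
            # absolute cosine similarity between flattened tensors i and j
            similarity = similarity + torch.cosine_similarity(multihead[i].view(1, -1), multihead[j].view(1, -1), dim=1).abs()
    # divides by N*N - N (the ordered-pair count); requires N >= 2
    loss = similarity / (N * N - N)
    return loss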
Python: cosine similarity. Input: a list of tensors.
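The pairwise loop can also be written without explicit Python loops. The block below is a minimal sketch of one way to do that, not part of the original code: the name `lm_div_vectorized` is hypothetical, and it assumes every tensor in the list has the same number of elements. It returns a 0-dimensional tensor rather than a tensor of shape `(1,)`, and `F.normalize` uses a different epsilon from `torch.cosine_similarity`, so results may differ slightly for near-zero vectors.

```python
import torch
import torch.nn.functional as F


def lm_div_vectorized(multihead):
    """Vectorized sketch of the same loss (hypothetical helper, not from the original)."""
    # Flatten each tensor to a row and stack into an (N, D) matrix.
    flat = torch.stack([h.reshape(-1) for h in multihead])
    n = flat.shape[0]
    # Unit-normalize the rows so the Gram matrix holds pairwise cosine similarities.
    unit = F.normalize(flat, dim=1)
    cos = unit @ unit.T
    # Sum |cos| over the strict upper triangle, i.e. over pairs with i < j.
    pair_sum = torch.triu(cos.abs(), diagonal=1).sum()
    # Same normalization as the loop version: divide by N*N - N.
    return pair_sum / (n * n - n)


# Three 2-D vectors: two orthogonal axes and their diagonal.
heads = [torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0]), torch.tensor([1.0, 1.0])]
loss = lm_div_vectorized(heads)
```

For these three vectors the pairwise absolute cosines are 0, 1/√2 and 1/√2, so the loss is √2/6, roughly 0.2357. Because the sum runs over unordered pairs while the divisor counts ordered pairs, the result is half the mean pairwise similarity.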