import torch.nn.functional as F

# rep_a and rep_b are [batch_size, hidden_dim] tensors whose rows are paired by
# index; both are assumed to be defined earlier in the script.
# F.pairwise_distance reduces over the last dimension, so `distance` is a
# [batch_size] tensor where distance[i] is the L2 (p=2) distance between
# rep_a[i] and rep_b[i].
# NOTE: per the PyTorch docs it computes ||rep_a - rep_b + eps||_p with a small
# default eps (1e-6) for numerical stability, so identical rows yield a tiny
# non-zero distance rather than exactly 0.
distance = F.pairwise_distance(rep_a, rep_b, p=2)