三维数组取mean可以得到二维,mean内嵌了一个torch.squeeze(),将数值为1的维度压缩
print(encoder_out.size()) # 50,196,2048
mean_encoder_out = encoder_out.mean(dim=1)
print(mean_encoder_out.size()) # 50,2048
三维数组取mean可以得到二维,mean内嵌了一个torch.squeeze(),将数值为1的维度压缩
print(encoder_out.size()) # 50,196,2048
mean_encoder_out = encoder_out.mean(dim=1)
print(mean_encoder_out.size()) # 50,2048