pytorch的torch.add()
及torch.split()
函数
import torch
# outputs是一个[batch, seq, 40]维的tensor,把outputs分割成两个[batch, seq, 20]的tensor,并每个元素求平均值
add = torch.add(*torch.split(outputs, 20, dim=2)) / 2
pytorch的torch.add()
及torch.split()
函数
import torch
# outputs是一个[batch, seq, 40]维的tensor,把outputs分割成两个[batch, seq, 20]的tensor,并每个元素求平均值
add = torch.add(*torch.split(outputs, 20, dim=2)) / 2