RuntimeError: cdist only supports 2D tensors, X1 got: 3D

RuntimeError: cdist only supports 2D tensors, X1 got: 3D

报错信息:RuntimeError: cdist only supports 2D tensors, X1 got: 3D
在使用torch.cdist()时报以上错误信息,具体程序是dist=torch.cdist(x.unsqueeze(0), x.unsqueeze(0)).squeeze(0),当时使用的pytorch版本是1.1,最后换用1.3版本程序正常运行。具体原因目前还没有找到,有同样问题的可以尝试换一下版本。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
修改后的代码如下: ```python class FaceKernelCorrelation(nn.Module): def __init__(self, num_kernel=64, sigma=0.2): super(FaceKernelCorrelation, self).__init__() self.num_kernel = num_kernel self.sigma = sigma self.weight_alpha = Parameter(torch.rand(1, num_kernel, 6) * np.pi) self.weight_beta = Parameter(torch.rand(1, num_kernel, 6) * 2 * np.pi) self.bn = nn.BatchNorm1d(num_kernel) self.relu = nn.ReLU() def forward(self, vertices, neighbor_index): b, _, n = vertices.size() center = vertices.unsqueeze(2).expand(-1, -1, self.num_kernel, -1).unsqueeze(4) neighbor = torch.gather(vertices.unsqueeze(3).expand(-1, -1, -1, 3), 2, neighbor_index.unsqueeze(1).expand(-1, 3, -1, -1)) neighbor = neighbor.unsqueeze(2).expand(-1, -1, self.num_kernel, -1, -1) # 计算直线特征 line = neighbor - center length = torch.sqrt(torch.sum(line**2, dim=-1, keepdim=True)) direction = line / (length + 1e-8) fea = torch.cat([center, direction, length], dim=4) fea = fea.unsqueeze(5).expand(-1, -1, -1, -1, -1, 6) weight = torch.cat([torch.sin(self.weight_alpha) * torch.cos(self.weight_beta), torch.sin(self.weight_alpha) * torch.sin(self.weight_beta), torch.cos(self.weight_alpha)], 0) weight = weight.unsqueeze(0).expand(b, -1, -1, -1) weight = weight.unsqueeze(3).expand(-1, -1, -1, n, -1) weight = weight.unsqueeze(4).expand(-1, -1, -1, -1, 6, -1) dist = torch.sum((fea - weight)**2, 1) fea = torch.sum(torch.sum(np.e**(dist / (-2 * self.sigma**2)), 4), 3) / 16 return self.relu(self.bn(fea)) ``` 对比原有的代码,主要修改的地方如下: 1. 修改了 weight_alpha 和 weight_beta 的形状,将其从 4 改为 6,以便存储直线特征; 2. 在 forward 函数中,首先计算出所有点的邻居点,然后根据邻居点和中心点计算出直线特征(即方向向量和长度),并将其拼接在一起; 3. 将拼接后的直线特征与权重相减后,进行距离计算和高斯加权求和。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值