这个问题大多数是由于使用老版本torch导致的,github上有说这是某个版本的bug,再后续工作中已经进行了修复,要是非要使用该版本torch的话,我使用了一个比较笨拙的方法,写一个where代替torch.where()
代码如下:
def manual_where(condition, x, y):
# assert condition.shape == x.shape and condition.shape == y.shape, "Shapes of condition, x, and y must be the same."
assert condition.shape == x.shape or condition.shape == y.shape, "Shapes of condition, x, and y must be compatible."
assert condition.dtype == torch.bool, "Condition tensor must have dtype torch.bool."
if condition.shape != x.shape:
x = x.expand_as(condition)
if hasattr(y, 'shape')==True:
if condition.shape != y.shape:
y = y.expand_as(condition)
else:
y = y.expand_as(condition)
result = torch.empty_like(x)
indices = torch.nonzero(condition)
for index in indices:
result[tuple(index.tolist())] = x[tuple(index.tolist())]
indices = torch.nonzero(~condition)
for index in indices:
result[tuple(index.tolist())] = y[tuple(index.tolist())]
return result
这个版本是xy不需要同维度的,不过肯定没有torch写的快,不过能运行了
这是复现muzic中的getmusic时遇到并解决的问题