关于expand_as()函数
RuntimeError: The expanded size of the tensor must match the existing size at non-singleton dimension 1. Target sizes:. Tensor sizes:
本质上,
"""
前提:
a.shape为[6272, 64]
b.shape为[50, 25088, 64]
"""
a_1 = a.unsqueeze(0)
# 此时,a_1.shape为[1, 6272, 64]
a_2 = a_1.expand_as(b)
核心在于,进行expand_as操作时,a_1和b之间,存在两个维度的尺寸不符,无法进行该操作
此时应该保证其他维度尺寸相符,再进行expand_as()操作
b_shape = np.array(b.shape) # 保证b.shape以数组形式存储
b_shape = [b_shape[1], b_shape[2]] # 只保留向量b第二维和第三维的尺寸
a = a.expand(4, -1, -1).reshape(*b_shape) # 此处的4 = b_shape[1] / a.shape[0]
# 此时保证了其他维度的相同,直接expand_as()即可
a = a.unsqueeze(0).expand_as(b)
另外,
a = a.expand(4, -1, -1).reshape(*b_shape) # 此处的4 = b_shape[1] / a.shape[0]
原因在于