If you are not yet familiar with torch.unsqueeze, I recommend reading my earlier post on it first.
1. Tensor and Scalar
import torch

# X: a 2-D tensor of shape [2, 3]
X = torch.tensor(
    [[1,2,3],
     [4,5,6]]
)
print("X:\n",X)
print("X's shape:\n",X.shape)
# Y: a 0-dim tensor (a scalar), shape []
Y = torch.tensor(
    7
)
print("\nY:\n",Y)
print("Y's shape:\n",Y.shape)
# The product works in either order; Y is broadcast to X's shape
Z1 = X*Y
print("\nZ:\n",Z1)
print("Z1's shape:\n",Z1.shape)
Z2 = Y*X
print("\nZ:\n",Z2)
print("Z2's shape:\n",Z2.shape)
Output:
X:
tensor([[1, 2, 3],
[4, 5, 6]])
X's shape:
torch.Size([2, 3])
Y:
tensor(7)
Y's shape:
torch.Size([])
Z:
tensor([[ 7, 14, 21],
[28, 35, 42]])
Z1's shape:
torch.Size([2, 3])
Z:
tensor([[ 7, 14, 21],
[28, 35, 42]])
Z2's shape:
torch.Size([2, 3])
The scalar is broadcast directly to a tensor of shape [2,3].
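Here is a minimal sketch that checks this. Multiplying by the 0-dim tensor gives the same result as first expanding it to X's shape with `Tensor.expand_as`. The expansion only creates a view and does not copy data.

```python
import torch

X = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
Y = torch.tensor(7)               # 0-dim tensor, shape torch.Size([])

Y_expanded = Y.expand_as(X)       # view of Y with shape [2, 3]; every entry is 7
Z_broadcast = X * Y               # implicit broadcasting
Z_explicit = X * Y_expanded       # explicit expansion, same computation

print(Y_expanded.shape)                      # torch.Size([2, 3])
print(torch.equal(Z_broadcast, Z_explicit))  # True
```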
2. Tensor and Tensor: [2,3]*[3]
Only the following code is changed:
Y = torch.tensor(
[7,8,9]
)
Output:
X:
tensor([[1, 2, 3],
[4, 5, 6]])
X's shape:
torch.Size([2, 3])
Y:
tensor([7, 8, 9])
Y's shape:
torch.Size([3])
Z:
tensor([[ 7, 16, 27],
[28, 40, 54]])
Z1's shape:
torch.Size([2, 3])
Z:
tensor([[ 7, 16, 27],
[28, 40, 54]])
Z2's shape:
torch.Size([2, 3])
The tensor of shape [3] is broadcast to [2,3].
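Broadcasting aligns shapes starting from the last dimension. That means [3] is first treated as [1,3], and then the single row is repeated along dim 0. The sketch below performs both steps explicitly with `unsqueeze(0)` and `expand`, then compares the result with the implicit version.

```python
import torch

X = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
Y = torch.tensor([7, 8, 9])       # shape [3]

Y_row = Y.unsqueeze(0)            # step 1: prepend a size-1 dim -> shape [1, 3]
Y_full = Y_row.expand(2, 3)       # step 2: repeat the row along dim 0 -> shape [2, 3]

print(Y_full.tolist())                 # [[7, 8, 9], [7, 8, 9]]
print(torch.equal(X * Y, X * Y_full))  # True
```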
3. Tensor and Tensor: [2,3]*[3,1] raises an error
Only the following code is changed:
Y = torch.tensor(
[[7],[8],[9]]
)
Output:
Y:
tensor([[7],
[8],
[9]])
Y's shape:
torch.Size([3, 1])
Traceback (most recent call last):
File "D:/Venv/Test/0710Test/Test.py", line 16, in <module>
Z = X*Y
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
I haven't fully figured out this error yet.
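This error follows from the broadcasting rule that PyTorch shares with NumPy:

- Shapes are right-aligned, and the shorter shape is padded with leading 1s.
- In each aligned dimension, the two sizes must be equal, or one of them must be 1.

For [2,3] vs [3,1], the last dimension is fine: 3 vs 1, so the 1 is stretched to 3. Dim 0 compares 2 vs 3. Neither size is 1, so the shapes are incompatible, and that is the "non-singleton dimension 0" in the message.

The `broadcast_shape` function below is a hypothetical helper written only to illustrate the rule. It is cross-checked against `torch.broadcast_shapes`, which exists in recent PyTorch versions. Transposing Y is just one possible fix, assuming a row vector was intended.

```python
import torch

def broadcast_shape(a, b):
    """Hypothetical helper: result shape of broadcasting shapes a and b.

    Shapes are right-aligned; the shorter one is padded with leading 1s.
    Each aligned pair must be equal, or one of them must be 1.
    """
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for dim, (da, db) in enumerate(zip(a, b)):
        if da == db or db == 1:
            out.append(da)          # equal sizes, or b's size-1 dim is stretched
        elif da == 1:
            out.append(db)          # a's size-1 dim is stretched
        else:
            raise ValueError(f"size {da} vs {db} at non-singleton dimension {dim}")
    return tuple(out)

print(broadcast_shape((2, 3), ()))     # (2, 3)  section 1: tensor * scalar
print(broadcast_shape((2, 3), (3,)))   # (2, 3)  section 2
try:
    broadcast_shape((2, 3), (3, 1))    # section 3: dim 1 is fine (3 vs 1), dim 0 fails (2 vs 3)
except ValueError as e:
    print("error:", e)

# PyTorch's own shape computation agrees with the helper
print(torch.broadcast_shapes((2, 3), (3,)))   # torch.Size([2, 3])

# One possible fix, if a row vector was intended: transpose Y from [3, 1] to [1, 3]
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
Y = torch.tensor([[7], [8], [9]])
print((X * Y.T).shape)                 # torch.Size([2, 3])
```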
4. X.shape [5,3], Y.shape [5,6]: how can they be broadcast?
import torch
X = torch.tensor(
[[1,2,3],
[4,5,6],
[7,8,9],
[10,11,12],
[13,14,15]]
)
print("X:\n",X)
print("X's shape:\n",X.shape)
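# Insert a size-1 dim at position 1: [5, 3] -> [5, 1, 3]
X1 = X.unsqueeze(1)
print("X1:\n",X1)
print("X1's shape:\n",X1.shape,"\n")
Y = torch.ones((5,6))
print("\nY:\n",Y)
print("Y's shape:\n",Y.shape)
# Insert a size-1 dim at position 2: [5, 6] -> [5, 6, 1]
Y1 = Y.unsqueeze(2)
print("Y1:\n",Y1)
print("Y1's shape:\n",Y1.shape,"\n")
# [5, 1, 3] * [5, 6, 1] broadcasts to [5, 6, 3]
Z1 = X.unsqueeze(1)*Y.unsqueeze(2)
print("\nZ1:\n",Z1)
print("Z1's shape:",Z1.shape,"\n\n\n")
# [5, 3] * [5, 6] cannot be broadcast: this line raises RuntimeError
Z2 = X*Y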
Output:
X:
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12],
[13, 14, 15]])
X's shape:
torch.Size([5, 3])
X1:
tensor([[[ 1, 2, 3]],
[[ 4, 5, 6]],
[[ 7, 8, 9]],
[[10, 11, 12]],
[[13, 14, 15]]])
X1's shape:
torch.Size([5, 1, 3])
Y:
tensor([[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1.]])
Y's shape:
torch.Size([5, 6])
Y1:
tensor([[[1.],
[1.],
[1.],
[1.],
[1.],
[1.]],
[[1.],
[1.],
[1.],
[1.],
[1.],
[1.]],
[[1.],
[1.],
[1.],
[1.],
[1.],
[1.]],
[[1.],
[1.],
[1.],
[1.],
[1.],
[1.]],
[[1.],
[1.],
[1.],
[1.],
[1.],
[1.]]])
Y1's shape:
torch.Size([5, 6, 1])
Z1:
tensor([[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.],
[ 4., 5., 6.]],
[[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.],
[10., 11., 12.]],
[[13., 14., 15.],
[13., 14., 15.],
[13., 14., 15.],
[13., 14., 15.],
[13., 14., 15.],
[13., 14., 15.]]])
Z1's shape: torch.Size([5, 6, 3])
Traceback (most recent call last):
File "D:/Venv/Test/0710Test/Test.py", line 29, in <module>
Z2 = X*Y
RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 1
Process finished with exit code 1
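The X above plays the role of the rays_direction vector, with shape [num_rays, 3], holding the direction vector of every ray.
The Y above plays the role of the z_val vector, with shape [num_rays, num_sample], holding the distances of all sample points along every ray.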
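So for every sample point on every ray, its distance z (from Y) must be multiplied by that ray's row of X, its direction vector. This is the operation $z\vec{x}$.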
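The sketch below checks that the broadcast product really computes $z\vec{x}$ for every (ray, sample) pair by comparing it with an explicit double loop. The random values are made up for illustration. `rays_direction` and `z_val` stand in for X and Y.

```python
import torch

num_rays, num_samples = 5, 6
rays_direction = torch.randn(num_rays, 3)          # stands in for X: one direction per ray
z_val = torch.rand(num_rays, num_samples)          # stands in for Y: sample distances per ray

# Broadcast version: [num_rays, 1, 3] * [num_rays, num_samples, 1] -> [num_rays, num_samples, 3]
scaled = rays_direction.unsqueeze(1) * z_val.unsqueeze(2)

# Loop version: sample j on ray i is z_val[i, j] times the direction of ray i
looped = torch.empty(num_rays, num_samples, 3)
for i in range(num_rays):
    for j in range(num_samples):
        looped[i, j] = z_val[i, j] * rays_direction[i]

print(scaled.shape)                    # torch.Size([5, 6, 3])
print(torch.allclose(scaled, looped))  # True
```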
X: [5,3]
Y: [5,6]
X.unsqueeze(1)*Y.unsqueeze(2): [5,6,3]
X.unsqueeze(1): $[5,3] \rightarrow [5,1,\mathbf{3}]$
Y.unsqueeze(2): $[5,6] \rightarrow [5,\mathbf{6},1]$
When X.unsqueeze(1)*Y.unsqueeze(2) is evaluated, X.unsqueeze(1) and Y.unsqueeze(2) are both broadcast to shape [5,6,3].
Compare the two shapes above, [5,1,3] and [5,6,1]. In every dimension the two sizes are either equal or one of them is 1, so broadcasting is allowed. Each dimension of the result then takes the larger of the two sizes:
dim 0: $\max\{5,5\}=5$
dim 1: $\max\{1,6\}=6$
dim 2: $\max\{3,1\}=3$
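To make the broadcast visible, the sketch below expands both operands to [5,6,3] with `Tensor.expand`. It then checks that their product matches the implicitly broadcast one. X and Y hold the same values as in the code above. Note that the common shape [5,6,3] is neither input's shape.

```python
import torch

X = torch.arange(1, 16).reshape(5, 3)   # same values as X above: [[1,2,3], ..., [13,14,15]]
Y = torch.ones(5, 6)

X1 = X.unsqueeze(1)                      # [5, 1, 3]
Y1 = Y.unsqueeze(2)                      # [5, 6, 1]

target = torch.broadcast_shapes(X1.shape, Y1.shape)
print(target)                            # torch.Size([5, 6, 3])

# expand() stretches only the size-1 dims, as views without copying data
X1e = X1.expand(target)                  # dim 1: 1 -> 6
Y1e = Y1.expand(target)                  # dim 2: 1 -> 3

print(torch.equal(X1 * Y1, X1e * Y1e))   # True
```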
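This is a good example for understanding unsqueeze.
In an element-wise multiplication (the `*` operator), each dimension of the result takes the larger of the two operands' sizes, where the two sizes must be equal or one of them must be 1.
When one operand already has the full shape, as RGB does in the NeRF example below, the result has exactly that operand's shape. In general the result can differ from both inputs, as in [5,1,3]*[5,6,1] → [5,6,3].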
In NeRF there is a weight matrix of shape N_rays × N_samples,
and an RGB tensor of shape N_rays × N_samples × 3.
In volume rendering, the three RGB components of a given sample point are multiplied by the same weight.
See the code below:
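import torch

nRays = 2      # N_rays
nSamples = 3   # N_samples
# weight: one weight per sample point, shape [N_rays, N_samples]
weight = torch.tensor([[1,2,3],[7,4,8]])
# RGB: one color (3 components) per sample point, shape [N_rays, N_samples, 3]
RGB = torch.tensor([
    [[0,3,0],[4,5,6],[7,8,9]],
    [[10,11,18],[8,58,6],[70,82,9]]
])
assert weight.shape == (nRays, nSamples)
assert RGB.shape == (nRays, nSamples, 3)
print("weight's shape",weight.shape)
print("weight",weight)
print("\n", "RGB's shape",RGB.shape)
print("RGB",RGB,"\n")
# Append a trailing size-1 dim: [N_rays, N_samples] -> [N_rays, N_samples, 1]
weight_after_unsqueeze = weight.unsqueeze(-1)
print("weight_after_unsqueeze's shape: ",weight_after_unsqueeze.shape)
print("weight_after_unsqueeze",weight_after_unsqueeze)
# [N_rays, N_samples, 1] * [N_rays, N_samples, 3] -> [N_rays, N_samples, 3]
res2 = weight_after_unsqueeze*RGB
print("shape: ",res2.shape)
print(res2)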
Output:
weight's shape torch.Size([2, 3])
weight tensor([[1, 2, 3],
[7, 4, 8]])
RGB's shape torch.Size([2, 3, 3])
RGB tensor([[[ 0, 3, 0],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 11, 18],
[ 8, 58, 6],
[70, 82, 9]]])
weight_after_unsqueeze's shape: torch.Size([2, 3, 1])
weight_after_unsqueeze tensor([[[1],
[2],
[3]],
[[7],
[4],
[8]]])
shape: torch.Size([2, 3, 3])
tensor([[[ 0, 3, 0],
[ 8, 10, 12],
[ 21, 24, 27]],
[[ 70, 77, 126],
[ 32, 232, 24],
[560, 656, 72]]])
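In the code above, N_rays and N_samples are 2 and 3 respectively.
After unsqueeze, weight is printed in a vertical layout, and that layout is a useful hint. Each number sits in a size-1 column of its own, and that column is stretched across the last dimension. So all three RGB components of a sample are multiplied by the same number, the one in their column.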
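This short sketch verifies that claim element by element, reusing the weight and RGB values above. Every component c of sample j on ray i is multiplied by weight[i, j]. It also shows that indexing with `None` (`weight[..., None]`) inserts the same trailing size-1 dimension as `unsqueeze(-1)`, which is just an alternative spelling.

```python
import torch

weight = torch.tensor([[1, 2, 3], [7, 4, 8]])                   # [N_rays, N_samples]
RGB = torch.tensor([[[0, 3, 0], [4, 5, 6], [7, 8, 9]],
                    [[10, 11, 18], [8, 58, 6], [70, 82, 9]]])   # [N_rays, N_samples, 3]

res = weight.unsqueeze(-1) * RGB

# weight[..., None] is an equivalent spelling of weight.unsqueeze(-1)
assert torch.equal(weight[..., None], weight.unsqueeze(-1))

# Each RGB component of sample (i, j) is scaled by the same weight[i, j]
for i in range(weight.shape[0]):
    for j in range(weight.shape[1]):
        for c in range(3):
            assert res[i, j, c] == weight[i, j] * RGB[i, j, c]

print(res[1, 2].tolist())   # [560, 656, 72]
```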