Using Tensor unsqueeze for broadcasting

If you are not yet familiar with torch.unsqueeze, it is worth reading the author's earlier article on it first.

1. Tensor and scalar

import torch

X = torch.tensor(
    [[1,2,3],
     [4,5,6]]
)
print("X:\n",X)
print("X's shape:\n",X.shape)

Y = torch.tensor(
    7
)
print("\nY:\n",Y)
print("Y's shape:\n",Y.shape)

Z1 = X*Y
print("\nZ:\n",Z1)
print("Z1's shape:\n",Z1.shape)

Z2 = Y*X
print("\nZ:\n",Z2)
print("Z2's shape:\n",Z2.shape)

Output:

X:
 tensor([[1, 2, 3],
        [4, 5, 6]])
X's shape:
 torch.Size([2, 3])

Y:
 tensor(7)
Y's shape:
 torch.Size([])

Z:
 tensor([[ 7, 14, 21],
        [28, 35, 42]])
Z1's shape:
 torch.Size([2, 3])

Z:
 tensor([[ 7, 14, 21],
        [28, 35, 42]])
Z2's shape:
 torch.Size([2, 3])

The scalar is broadcast directly to a tensor of shape [2,3].

2. Tensor and tensor: [2,3]*[3]

Only the definition of Y changes:

Y = torch.tensor(
    [7,8,9]
)

Result:

X:
 tensor([[1, 2, 3],
        [4, 5, 6]])
X's shape:
 torch.Size([2, 3])

Y:
 tensor([7, 8, 9])
Y's shape:
 torch.Size([3])

Z:
 tensor([[ 7, 16, 27],
        [28, 40, 54]])
Z1's shape:
 torch.Size([2, 3])

Z:
 tensor([[ 7, 16, 27],
        [28, 40, 54]])
Z2's shape:
 torch.Size([2, 3])

The tensor of shape [3] is broadcast to shape [2,3].
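
One way to see what broadcasting does in this case is to expand Y by hand and compare the results. The following is a small sketch, not part of the original code; the names Y_expanded and Z_manual are introduced only for this illustration.

```python
import torch

X = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
Y = torch.tensor([7, 8, 9])

# Broadcasting treats Y as if it had shape [1, 3] and repeats it along dim 0.
# expand() builds that [2, 3] view explicitly, without copying the data.
Y_expanded = Y.unsqueeze(0).expand(2, 3)
Z_manual = X * Y_expanded

print(Y_expanded)
print(torch.equal(Z_manual, X * Y))  # True
```

Each row of X is multiplied element by element with the same [7, 8, 9], which is why the implicit and the explicit version agree.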

3. Tensor and tensor: [2,3]*[3,1] raises an error

Again, only the definition of Y changes:

Y = torch.tensor(
    [[7],[8],[9]]
)

Output:

Y:
 tensor([[7],
        [8],
        [9]])
Y's shape:
 torch.Size([3, 1])
Traceback (most recent call last):
  File "D:/Venv/Test/0710Test/Test.py", line 16, in <module>
    Z = X*Y
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

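This error follows from the broadcasting rule: shapes are compared from the last dimension backwards, and two sizes are compatible only if they are equal or one of them is 1. Here the last dimensions are 3 and 1, which is fine, but dimension 0 pairs 2 with 3, and neither of them is 1, so the shapes cannot be broadcast.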
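
The broadcasting rule behind this error can be sketched in a few lines of plain Python. The helper broadcast_shape below is hypothetical and written only for illustration; it is not PyTorch's implementation, it just mirrors the documented rule (compare sizes from the last dimension backwards; sizes must be equal or one of them must be 1).

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of shapes a and b, or raise ValueError."""
    result = []
    # Walk both shapes from the last dimension backwards;
    # a missing leading dimension counts as size 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"sizes {da} and {db} do not match at dimension -{i}")
        # A size of 1 is stretched to the other size.
        result.append(db if da == 1 else da)
    return tuple(reversed(result))

print(broadcast_shape((2, 3), ()))      # (2, 3)  scalar case
print(broadcast_shape((2, 3), (3,)))    # (2, 3)  section 2
print(broadcast_shape((2, 3), (2, 1)))  # (2, 3)  a column that does work
try:
    broadcast_shape((2, 3), (3, 1))     # section 3
except ValueError as err:
    print("error:", err)
```

A column-shaped Y is therefore fine as long as it has shape [2,1] rather than [3,1]. Recent PyTorch versions also offer torch.broadcast_shapes for checking shapes without creating tensors.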

4. X.shape: [5,3], Y.shape: [5,6]: how to broadcast?

import torch

X = torch.tensor(
    [[1,2,3],
     [4,5,6],
     [7,8,9],
     [10,11,12],
     [13,14,15]]
)
print("X:\n",X)
print("X's shape:\n",X.shape)

X1 = X.unsqueeze(1)
print("X1:\n",X1)
print("X1's shape:\n",X1.shape,"\n")

Y = torch.ones((5,6))
print("\nY:\n",Y)
print("Y's shape:\n",Y.shape)


Y1 = Y.unsqueeze(2)
print("Y1:\n",Y1)
print("Y1's shape:\n",Y1.shape,"\n")

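# [5,1,3] * [5,6,1] broadcasts to [5,6,3]
Z1 = X.unsqueeze(1)*Y.unsqueeze(2)
print("\nZ1:\n",Z1)
print("Z1's shape:",Z1.shape,"\n\n\n")
# Raises RuntimeError: [5,3] and [5,6] cannot be broadcast directly
Z2 = X*Y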

Output:

X:
 tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12],
        [13, 14, 15]])
X's shape:
 torch.Size([5, 3])
X1:
 tensor([[[ 1,  2,  3]],

        [[ 4,  5,  6]],

        [[ 7,  8,  9]],

        [[10, 11, 12]],

        [[13, 14, 15]]])
X1's shape:
 torch.Size([5, 1, 3]) 


Y:
 tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
Y's shape:
 torch.Size([5, 6])
Y1:
 tensor([[[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]]])
Y1's shape:
 torch.Size([5, 6, 1]) 


Z1:
 tensor([[[ 1.,  2.,  3.],
         [ 1.,  2.,  3.],
         [ 1.,  2.,  3.],
         [ 1.,  2.,  3.],
         [ 1.,  2.,  3.],
         [ 1.,  2.,  3.]],

        [[ 4.,  5.,  6.],
         [ 4.,  5.,  6.],
         [ 4.,  5.,  6.],
         [ 4.,  5.,  6.],
         [ 4.,  5.,  6.],
         [ 4.,  5.,  6.]],

        [[ 7.,  8.,  9.],
         [ 7.,  8.,  9.],
         [ 7.,  8.,  9.],
         [ 7.,  8.,  9.],
         [ 7.,  8.,  9.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [10., 11., 12.],
         [10., 11., 12.],
         [10., 11., 12.],
         [10., 11., 12.],
         [10., 11., 12.]],

        [[13., 14., 15.],
         [13., 14., 15.],
         [13., 14., 15.],
         [13., 14., 15.],
         [13., 14., 15.],
         [13., 14., 15.]]])
Z1's shape: torch.Size([5, 6, 3]) 



Traceback (most recent call last):
  File "D:/Venv/Test/0710Test/Test.py", line 29, in <module>
    Z2 = X*Y
RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 1

Process finished with exit code 1

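X above plays the role of the rays_direction tensor, of shape [num_rays, 3], which holds the direction vector of every ray;
Y above plays the role of the z_val tensor, of shape [num_rays, num_sample], which holds the distance of every sample point along every ray.
For every sample point on every ray, the distance has to multiply the matching row of X, i.e. the operation $z\vec{x}$.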

X: [5,3]
Y: [5,6]
X.unsqueeze(1)*Y.unsqueeze(2): [5,6,3]
X.unsqueeze(1): $[5,3] \rightarrow [5,1,\mathbf{3}]$
Y.unsqueeze(2): $[5,6] \rightarrow [5,\mathbf{6},1]$

When X.unsqueeze(1)*Y.unsqueeze(2) is executed,
both X.unsqueeze(1) and Y.unsqueeze(2) are broadcast to shape [5,6,3].
Compare the two shapes above: in every dimension the two sizes are either equal or one of them is 1, so the result takes the larger of the two sizes in each dimension. That is:
dim 1: $\max\{1,6\}=6$
dim 2: $\max\{3,1\}=3$
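
The same product can also be written with None indexing or with torch.einsum, which spells out the index pattern. The sketch below checks that the three forms agree. The names rays_d and z_vals are placeholders chosen for this illustration, and the random distances stand in for real sample data.

```python
import torch

num_rays, num_samples = 5, 6
# Same values as X above, as floats; shape [5, 3].
rays_d = torch.arange(1, 16, dtype=torch.float32).reshape(num_rays, 3)
# Stand-in for Y; shape [5, 6].
z_vals = torch.rand(num_rays, num_samples)

# [5, 1, 3] * [5, 6, 1] -> [5, 6, 3]
scaled_a = rays_d.unsqueeze(1) * z_vals.unsqueeze(2)
# Indexing with None inserts a size-1 dimension, just like unsqueeze.
scaled_b = rays_d[:, None, :] * z_vals[:, :, None]
# einsum names the axes: r = ray, s = sample, c = coordinate.
scaled_c = torch.einsum("rc,rs->rsc", rays_d, z_vals)

print(scaled_a.shape)
print(torch.equal(scaled_a, scaled_b))
print(torch.allclose(scaled_a, scaled_c))
```

Entry [r, s, :] of the result is the direction of ray r scaled by the distance of its s-th sample, which is exactly the per-sample product described above.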

A good example for understanding unsqueeze

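After element-wise multiplication (the "*" operator) with broadcasting,
the result in this example has the same shape as the operand with more dimensions.
NeRF has a weight matrix whose shape is [N_rays, N_samples].
It also has an RGB tensor whose shape is [N_rays, N_samples, 3].
In volume rendering, the three RGB components of one sample point share the same coefficient.
See the code below: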

import torch

nRays = 2     # number of rays
nSamples = 3  # number of samples per ray

weight = torch.tensor([[1,2,3],[7,4,8]])

RGB = torch.tensor([
    [[0,3,0],[4,5,6],[7,8,9]],
    [[10,11,18],[8,58,6],[70,82,9]]
])

print("weight's shape",weight.shape)
print("weight",weight)

print("\n", "RGB's shape",RGB.shape)
print("RGB",RGB,"\n")

weight_after_unsqueeze = weight.unsqueeze(-1)
print("weight_after_unsqueeze's shape: ",weight_after_unsqueeze.shape)
print("weight_after_unsqueeze",weight_after_unsqueeze)


res2 = weight_after_unsqueeze*RGB
print("shape: ",res2.shape)
print(res2)

Output:

weight's shape torch.Size([2, 3])
weight tensor([[1, 2, 3],
        [7, 4, 8]])

 RGB's shape torch.Size([2, 3, 3])
RGB tensor([[[ 0,  3,  0],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 11, 18],
         [ 8, 58,  6],
         [70, 82,  9]]]) 

weight_after_unsqueeze's shape:  torch.Size([2, 3, 1])
weight_after_unsqueeze tensor([[[1],
         [2],
         [3]],

        [[7],
         [4],
         [8]]])
shape:  torch.Size([2, 3, 3])
tensor([[[  0,   3,   0],
         [  8,  10,  12],
         [ 21,  24,  27]],

        [[ 70,  77, 126],
         [ 32, 232,  24],
         [560, 656,  72]]])

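In the code above, N_rays and N_samples are 2 and 3 respectively.
The printed form of weight after unsqueeze, with the numbers stacked vertically, is a useful hint: each weight sits alone in the last dimension, so all three RGB components of the same sample are multiplied by that single number.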
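
The code above stops at the weighted colors. As a sketch of the step that typically follows in volume rendering (an assumption here, since the original code does not show it), the weighted colors of all samples on a ray can be summed into one color per ray. The name pixel_color is introduced only for this illustration.

```python
import torch

weight = torch.tensor([[1, 2, 3], [7, 4, 8]])   # [N_rays, N_samples]
RGB = torch.tensor([
    [[0, 3, 0], [4, 5, 6], [7, 8, 9]],
    [[10, 11, 18], [8, 58, 6], [70, 82, 9]],
])                                               # [N_rays, N_samples, 3]

# Weight every sample's color, then sum over the sample dimension (dim 1).
pixel_color = (weight.unsqueeze(-1) * RGB).sum(dim=1)   # [N_rays, 3]
print(pixel_color.shape)
print(pixel_color)  # rows: [29, 37, 39] and [662, 965, 222]
```

Without the unsqueeze(-1), the shapes [2,3] and [2,3,3] would be aligned from the last dimension, so each weight would be paired with a color channel instead of a sample. With these particular sizes that would not even raise an error, which makes the explicit unsqueeze worth keeping.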
