pytorch中dim的含义及相关做法

75 篇文章 2 订阅

1. DIM的表示

我们在 PyTorch 中经常看到函数需要设置 dim=0、1、-1 之类的参数,但往往不清楚应当如何选择;通过下面的测试可以直观理解其含义。
在这里插入图片描述

  • dim=0;举例求和:torch.sum(x,dim=0)举例

    即 x[0] 与 x[1] 两个 3×4 切片逐元素相加:
    [[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]] + [[12., 13., 14., 15.], [16., 17., 18., 19.], [20., 21., 22., 23.]]
    torch.sum(x, dim=0) = [[12., 14., 16., 18.], [20., 22., 24., 26.], [28., 30., 32., 34.]]

  • dim=1,举例求和:torch.sum(x,dim=1)举例
    [0., 1., 2., 3.] + [4., 5., 6., 7.] + [8., 9., 10., 11.] = [12., 15., 18., 21.]
    [12., 13., 14., 15.] + [16., 17., 18., 19.] + [20., 21., 22., 23.] = [48., 51., 54., 57.]
    torch.sum(x, dim=1) = [[12., 15., 18., 21.], [48., 51., 54., 57.]]

  • dim=2,举例求和:torch.sum(x,dim=2)举例
    0+1+2+3=6;  4+5+6+7=22;  8+9+10+11=38
    12+13+14+15=54;  16+17+18+19=70;  20+21+22+23=86
    torch.sum(x, dim=2) = [[6., 22., 38.], [54., 70., 86.]]

2. 代码

我们在用dim=0,1,2的时候可以进行从外到里一步步拆分即可;

# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: dim_test
# @Create time: 2022/2/18 13:15
import torch
from torch import nn


# Demonstrate nn.Softmax along each axis of a (2, 3, 4) all-ones tensor.
# Since every element is equal, softmax over an axis of size k yields 1/k
# at every position, which makes the per-dim normalization easy to see.
# NOTE: renamed the tensor from `input` to `x` to avoid shadowing the
# builtin input(); the printed labels are kept unchanged.
x = torch.ones((2, 3, 4))
softmax_dim_0 = nn.Softmax(dim=0)  # normalize over axis 0 (size 2) -> 0.5
softmax_dim_1 = nn.Softmax(dim=1)  # normalize over axis 1 (size 3) -> 1/3
softmax_dim_2 = nn.Softmax(dim=2)  # normalize over axis 2 (size 4) -> 0.25

output_dim_0 = softmax_dim_0(x)
output_dim_1 = softmax_dim_1(x)
output_dim_2 = softmax_dim_2(x)
print(f"input={x}")
print(f"input.shape={x.shape}")
print(f"output_dim_0={output_dim_0}")
print(f"output_dim_0.shape={output_dim_0.shape}")
print(f"output_dim_1={output_dim_1}")
print(f"output_dim_1.shape={output_dim_1.shape}")
print(f"output_dim_2={output_dim_2}")
print(f"output_dim_2.shape={output_dim_2.shape}")

# 2-D example on a 3x4 matrix of 0..11: reducing over dim=0 collapses the
# rows (result length 4), reducing over dim=1 collapses the columns
# (result length 3). argmax follows the same axis convention.
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
x_sum_0 = x.sum(dim=0)        # column sums: [12., 15., 18., 21.]
x_sum_1 = x.sum(dim=1)        # row sums:    [6., 22., 38.]
x_argmax_0 = x.argmax(dim=0)  # max lives in the last row -> all 2
x_argmax_1 = x.argmax(dim=1)  # max lives in the last column -> all 3
print(f"x={x}")
print(f"x_sum_0={x_sum_0}")
print(f"x_sum_1={x_sum_1}")
print(f"x_argmax_0={x_argmax_0}")
print(f"x_argmax_1={x_argmax_1}")

# 3-D example on a (2, 3, 4) tensor of 0..23: summing over dim=d removes
# axis d from the shape, so the result shapes are (3,4), (2,4) and (2,3).
x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
print(f"x={x}")
x_sum_dim_0 = x.sum(dim=0)  # (2,3,4) -> (3,4): adds the two slices
x_sum_dim_1 = x.sum(dim=1)  # (2,3,4) -> (2,4): adds the 3 rows per slice
x_sum_dim_2 = x.sum(dim=2)  # (2,3,4) -> (2,3): adds the 4 entries per row
print(f"x_sum_dim_0={x_sum_dim_0}")
print(f"x_sum_dim_0.shape={x_sum_dim_0.shape}")
print(f"x_sum_dim_1={x_sum_dim_1}")
print(f"x_sum_dim_1.shape={x_sum_dim_1.shape}")
print(f"x_sum_dim_2={x_sum_dim_2}")
print(f"x_sum_dim_2.shape={x_sum_dim_2.shape}")

# 2-D all-ones sanity check: dim=0 sums down the 3 rows (3.0 per column),
# dim=1 sums across the 4 columns (4.0 per row).
# NOTE: renamed from `input` to avoid shadowing the builtin input();
# the printed labels are kept unchanged.
ones_2d = torch.ones(3, 4)
print(f"input={ones_2d}")
print(f"torch.sum(input,dim=0)={torch.sum(ones_2d, dim=0)}")
print(f"torch.sum(input,dim=1)={torch.sum(ones_2d, dim=1)}")
  • 结果:
input=tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
input.shape=torch.Size([2, 3, 4])
output_dim_0=tensor([[[0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000]],

        [[0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000]]])
output_dim_0.shape=torch.Size([2, 3, 4])
output_dim_1=tensor([[[0.3333, 0.3333, 0.3333, 0.3333],
         [0.3333, 0.3333, 0.3333, 0.3333],
         [0.3333, 0.3333, 0.3333, 0.3333]],

        [[0.3333, 0.3333, 0.3333, 0.3333],
         [0.3333, 0.3333, 0.3333, 0.3333],
         [0.3333, 0.3333, 0.3333, 0.3333]]])
output_dim_1.shape=torch.Size([2, 3, 4])
output_dim_2=tensor([[[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]]])
output_dim_2.shape=torch.Size([2, 3, 4])
x=tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
x_sum_0=tensor([12., 15., 18., 21.])
x_sum_1=tensor([ 6., 22., 38.])
x_argmax_0=tensor([2, 2, 2, 2])
x_argmax_1=tensor([3, 3, 3])
x=tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
x_sum_dim_0=tensor([[12., 14., 16., 18.],
        [20., 22., 24., 26.],
        [28., 30., 32., 34.]])
x_sum_dim_0.shape=torch.Size([3, 4])
x_sum_dim_1=tensor([[12., 15., 18., 21.],
        [48., 51., 54., 57.]])
x_sum_dim_1.shape=torch.Size([2, 4])
x_sum_dim_2=tensor([[ 6., 22., 38.],
        [54., 70., 86.]])
x_sum_dim_2.shape=torch.Size([2, 3])
input=tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
torch.sum(input,dim=0)=tensor([3., 3., 3., 3.])
torch.sum(input,dim=1)=tensor([4., 4., 4.])
  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值