PyTorch学习笔记(二):索引与切片

Indexing: dim 0 first

a = torch.rand(4,3,28,28)
print(a.dim()) 
print(a.shape)
# 输出
4
torch.Size([4, 3, 28, 28])

print(a[0].shape)           # torch.Size([3, 28, 28])
print(a[0, 0].shape)        # torch.Size([28, 28])
print(a[0, 0, 0].shape)     # torch.Size([28])
print(a[0, 0, 0, 0].shape)  # torch.Size([])

print(a[0, ...].shape)      # torch.Size([3, 28, 28])
print(a[0, ..., 0].shape)   # torch.Size([3, 28])
print(a[..., :2].shape)     # torch.Size([4, 3, 28, 2])

select first/last N

print(a[:2].shape)             # torch.Size([2, 3, 28, 28])
print(a[:2, :1, :, :].shape)   # torch.Size([2, 1, 28, 28])
print(a[1:, :, :, :].shape)    # torch.Size([3, 3, 28, 28])

select by steps

print(a[:, :, 0:28:2, 0:28:2].shape)   # torch.Size([4, 3, 14, 14])
print(a[:, :, ::2, ::2].shape)         # torch.Size([4, 3, 14, 14])

select by specific index

print(a.index_select(0, torch.tensor([0,2])).shape)  # torch.Size([2, 3, 28, 28])
print(a.index_select(1, torch.arange(2)).shape)      # torch.Size([4, 2, 28, 28])

select by mask

a = torch.rand(3,4)
print(a)
# 输出
tensor([[0.9796, 0.9025, 0.2744, 0.4932],
        [0.6778, 0.5818, 0.7009, 0.6437],
        [0.2674, 0.8005, 0.6140, 0.6765]])

mask = torch.tensor([[False, False, False, True],
                     [True, False, False, False],
                     [False, True, False, False]])
a = torch.masked_select(a, mask)
print(a)          # tensor([0.4932, 0.6778, 0.8005])
print(a.shape)    # torch.Size([3])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值