torch.cat 函数的使用
torch.cat 函数可以将两个张量根据指定的维度拼接起来.
import torch
def test():
data1 = torch.randint(0, 10, [3, 5, 4])
data2 = torch.randint(0, 10, [3, 5, 4])
print(data1)
print(data2)
print('-' * 50)
# 1. 按0维度拼接
new_data = torch.cat([data1, data2], dim=0)
print(new_data.shape)
print('-' * 50)
# 2. 按1维度拼接
new_data = torch.cat([data1, data2], dim=1)
print(new_data.shape)
# 3. 按2维度拼接
new_data = torch.cat([data1, data2], dim=2)
print(new_data)
if __name__ == '__main__':
test()
输出结果:
tensor([[[6, 8, 3, 5],
[1, 1, 3, 8],
[9, 0, 4, 4],
[1, 4, 7, 0],
[5, 1, 4, 8]],
[[0, 1, 4, 4],
[4, 1, 8, 7],
[5, 2, 6, 6],
[2, 6, 1, 6],
[0, 7, 8, 9]],
[[0, 6, 8, 8],
[5, 4, 5, 8],
[3, 5, 5, 9],
[3, 5, 2, 4],
[3, 8, 1, 1]]])
tensor([[[4, 6, 8, 1],
[0, 1, 8, 2],
[4, 9, 9, 8],
[5, 1, 5, 9],
[9, 4, 3, 0]],
[[7, 6, 3, 3],
[4, 3, 3, 2],
[2, 1, 1, 1],
[3, 0, 8, 2],
[8, 6, 6, 5]],
[[0, 7, 2, 4],
[4, 3, 8, 3],
[4, 2, 1, 9],
[4, 2, 8, 9],
[3, 7, 0, 8]]])
--------------------------------------------------
torch.Size([6, 5, 4])
--------------------------------------------------
torch.Size([3, 10, 4])
tensor([[[6, 8, 3, 5, 4, 6, 8, 1],
[1, 1, 3, 8, 0, 1, 8, 2],
[9, 0, 4, 4, 4, 9, 9, 8],
[1, 4, 7, 0, 5, 1, 5, 9],
[5, 1, 4, 8, 9, 4, 3, 0]],
[[0, 1, 4, 4, 7, 6, 3, 3],
[4, 1, 8, 7, 4, 3, 3, 2],
[5, 2, 6, 6, 2, 1, 1, 1],
[2, 6, 1, 6, 3, 0, 8, 2],
[0, 7, 8, 9, 8, 6, 6, 5]],
[[0, 6, 8, 8, 0, 7, 2, 4],
[5, 4, 5, 8, 4, 3, 8, 3],
[3, 5, 5, 9, 4, 2, 1, 9],
[3, 5, 2, 4, 4, 2, 8, 9],
[3, 8, 1, 1, 3, 7, 0, 8]]])
torch.stack 函数的使用
torch.stack 函数可以将两个张量根据指定的维度叠加起来.
import torch
def test():
data1= torch.randint(0, 10, [2, 3])
data2= torch.randint(0, 10, [2, 3])
print(data1)
print(data2)
new_data = torch.stack([data1, data2], dim=0)
print(new_data.shape)
new_data = torch.stack([data1, data2], dim=1)
print(new_data.shape)
new_data = torch.stack([data1, data2], dim=2)
print(new_data)
if __name__ == '__main__':
test()
输出结果:
tensor([[5, 8, 7],
[6, 0, 6]])
tensor([[5, 8, 0],
[9, 0, 1]])
torch.Size([2, 2, 3])
torch.Size([2, 2, 3])
tensor([[[5, 5],
[8, 8],
[7, 0]],
[[6, 9],
[0, 0],
[6, 1]]])
简单行、列索引
准备数据
import torch
data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)
输出结果:
tensor([[0, 7, 6, 5, 9],
[6, 8, 3, 1, 0],
[6, 3, 8, 7, 3],
[4, 9, 5, 3, 1]])
# 1. 简单行、列索引
def test01():
print(data[0])
print(data[:, 0])
print('-' * 50)
if __name__ == '__main__':
test01()
输出结果:
tensor([0, 7, 6, 5, 9])
tensor([0, 6, 6, 4])
列表索引
# 2. 列表索引
def test02():
# 返回 (0, 1)、(1, 2) 两个位置的元素
print(data[[0, 1], [1, 2]])
print('-' * 50)
# 返回 0、1 行的 1、2 列共4个元素
print(data[[[0], [1]], [1, 2]])
if __name__ == '__main__':
test02()
输出结果:
tensor([7, 3])
--------------------------------------------------
tensor([[7, 6],
[8, 3]])
范围索引
# 3. 范围索引
def test03():
# 前3行的前2列数据
print(data[:3, :2])
# 第2行到最后的前2列数据
print(data[2:, :2])
if __name__ == '__main__':
test03()
输出结果:
tensor([[0, 7],
[6, 8],
[6, 3]])
tensor([[6, 3],
[4, 9]])
布尔索引
# 布尔索引
def test():
# 第三列大于5的行数据
print(data[data[:, 2] > 5])
# 第二行大于5的列数据
print(data[:, data[1] > 5])
if __name__ == '__main__':
test04()
输出结果:
tensor([[0, 7, 6, 5, 9],
[6, 3, 8, 7, 3]])
tensor([[0, 7],
[6, 8],
[6, 3],
[4, 9]])
多维索引
# 多维索引
def test05():
data = torch.randint(0, 10, [3, 4, 5])
print(data)
print('-' * 50)
print(data[0, :, :])
print(data[:, 0, :])
print(data[:, :, 0])
if __name__ == '__main__':
test05()
输出结果:
tensor([[[2, 4, 1, 2, 3],
[5, 5, 1, 5, 0],
[1, 4, 5, 3, 8],
[7, 1, 1, 9, 9]],
[[9, 7, 5, 3, 1],
[8, 8, 6, 0, 1],
[6, 9, 0, 2, 1],
[9, 7, 0, 4, 0]],
[[0, 7, 3, 5, 6],
[2, 4, 6, 4, 3],
[2, 0, 3, 7, 9],
[9, 6, 4, 4, 4]]])
--------------------------------------------------
tensor([[2, 4, 1, 2, 3],
[5, 5, 1, 5, 0],
[1, 4, 5, 3, 8],
[7, 1, 1, 9, 9]])
tensor([[2, 4, 1, 2, 3],
[9, 7, 5, 3, 1],
[0, 7, 3, 5, 6]])
tensor([[2, 5, 1, 7],
[9, 8, 6, 9],
[0, 2, 2, 9]])