拼接cat和叠加stack:
cat是直接把两个张量里面的元素拼接起来:[2,3,4],[2,3,4]的两个张量拼接第3维度即dim=2结果是[2,3,8]
stack是叠加哪个维度的,就哪个维度当成一个元素一样叠加成一个元素;[2,3,4],[2,3,4]的两个张量叠加第3维度即dim=2即1,2维度不动,的第3维度的每一行相互叠加,然后第3维变成[2,4]张量变成[2,3,2,4]
def torch11(): # 拼接函数 cat
data1 = torch.randint(0, 10, [3, 4, 5])# 3个4行5列
data2 = torch.randint(0, 10, [3, 4, 5])
print(data1.shape)
print(data2.shape)
new_data = torch.cat([data1, data2], dim=0) # 拼接 dim表示拼接维度0就是拼接第一位,1就是拼接第二位
print(new_data.shape)
def torch12(): # 叠加函数 stack
data1 = torch.randint(0, 10, [3, 3, 4])
data2 = torch.randint(0, 10, [3, 3, 4])
print(data1.shape)
print(data2.shape)
new_data = torch.stack([data1, data2], dim=2)
print(new_data.shape)
索引部分
def torch13(): # 列表索引
data = torch.randint(10, 20, [4, 5])
print(data)
print(data[:3, 2]) # ,前面表示行后面表示列 :所有 :3前3列
print(data[3][2])
print(data[[[0], [2]], [0, 1, 3]]) # 第0,2行,0,1,3 元素
def torch14(): # 布尔索引
data = torch.randint(10, 20, [4, 5])
print(data)
print(data > 13)
print(data[data > 13]) # 返回大于13的元素
def torch15(): # 多维索引
data = torch.randint(10, 20, [3, 4, 5])
print(data)
print(data[:, 0, :])
张量的形状操作这里介绍几个函数
def torch16(): data = torch.randint(1, 10, [3, 4, 5]) print(data.shape) data = data.reshape(2, 6, 5) # 自定义形状,但是里面元素数量要相同 print(data.shape) print(data.reshape(4, 5, -1).shape) # -1代表不管是多少,由电脑进行运算判断是几 data = torch.transpose(data, 0, 2) # 交换维度,一次只能交换两个 print(data.shape) data = torch.permute(data, [0, 2, 1]) # 交换维度 print(data.shape) def torch17(): data = torch.randint(1, 10, [3, 1, 5, 6, 1]) print(data.shape) data = data.squeeze() # 删除1的维度 print(data.shape) data = data.unsqueeze(-1) # 增加一个1维度 print(data.shape)
torch.Size([3, 4, 5]) torch.Size([2, 6, 5]) torch.Size([4, 5, 3]) torch.Size([5, 6, 2]) torch.Size([5, 2, 6]) _______ torch.Size([3, 1, 5, 6, 1]) torch.Size([3, 5, 6]) torch.Size([3, 5, 6, 1])