张量操作
1、张量的操作:拼接、切分、索引和变换
1.1 张量的拼接与切分
拼接
torch.cat() # 不会扩张维度
功能:将张量按维度dim进行拼接
• tensors: 张量序列 • dim : 要拼接的维度
t = torch.ones((2, 3))
t_0 = torch.cat([t, t], dim=0) # 行拼接在一起
t_1 = torch.cat([t, t, t], dim=1) # 列拼接在一起
print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))
torch.stack() # 创建新的维度
功能:在新创建的维度dim上进行拼接
• tensors:张量序列
• dim :要拼接的维度
t = torch.ones((2, 3))
t_stack = torch.stack([t, t, t], dim=0) # 在0维新创建维度
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
t = torch.ones((2, 3))
t_stack = torch.stack([t, t], dim=2) # 在新创建的维度2拼接
切分
torch.chunk()
功能:将张量按维度dim进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份张量小于 其他张量
• input: 要切分的张量
• chunks : 要切分的份数,向上取整
• dim : 要切分的维度
a = torch.ones((2, 7))
list_of_tensors = torch.chunk(a, dim=1, chunks=3) #在维度1切分3份
for idx, t in enumerate(list_of_tensors):
print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))
torch.split()
功能:将张量按维度dim进行切分
返回值:张量列表
• tensor: 要切分的张量
• split_size_or_sections : 为int时,表示 每一份的长度;为list时,按list元素切分
• dim : 要切分的维度
t = torch.ones((2,