张量操作与线性回归
一、张量拼接与切分
1.1 张量拼接
1.1.1 torch.cat(tensor, dim=0,out=None)
功能:将张量按维度dim进行拼接。
tensors:张量序列;
dim:要拼接的维度
dim=0对应tensor(2, 3)里面的2,也就是张量中的行,所以t_0拼接出来的tensor:t_0 (4, 3),即4行3列;dim=1对应tensor(2, 3)里面的3,也就是张量中的列,所以拼接出来的tensor:t_1(2, 6)。
import torch
t = torch.ones((2, 3))
t_0 = torch.cat([t, t], dim=0)
t_1 = torch.cat([t, t], dim=1)
print(t_0)
print(t_1)
1.1.2 torch.stack(tensors, dim=0, out=None)
功能:在新创建的维度dim上进行拼接。
tensors:张量序列。
dim: 要拼接的维度。
stack 函数中的dim,会在新的维度上拼接张量。例如dim=2时,会得(2, 3, 2) 形状的张量。
import torch
t = torch.ones((2, 3))
t_stack = torch.stack([t, t], dim=2)
print(t_stack)
1.2 张量切分
1.2.1 torch.chunk(input, chunks, dim=0)
功能:将张量按维度dim进行平均切分。
返回值:张量列表
注意事项:若不能整除,最后一份张量小于其他张量
input:要切分的张量
chunks:要切分的分数
dim:要切分的维度
import torch
a = torch.ones((2, 5))
list_of_tensors = torch.chunk(a, dim=1, chunks=2)
for idx, t in enumerate(list_of_tensors):
print('第 {} 个张量:{}, shape is {}'.format(idx, t, t.shape))
1.2.12 torch.split(tensor, split_size_or_sections, dim=0)
功能:将张量按维度dim进行平均切分
返回值:张量列表
tensor:要切分的张量
split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分
dim:要切分的维度。
import torch
t = torch.ones((2, 5))
list_of_tensors = torch.split(t, 2, dim=1)
for idx, t in enumerate