张量操作
一、张量拼接
torch.cat(#将张量按维度dim进行拼接
tensor,#张量序列
dim=0,#要拼接的维度
out=None
)
torch.stack(#在新创建的维度dim上进行拼接
tensor,
dim,
out=None
)
二、张量切分
torch.chunk(#将张量按维度dim进行平均切分,返回值为张量列表,若不能整除最后一份张量小于其他张量
input,#要切分的张量
chunks,#要切分的份数
dim=0#要切分的维度
)
torch.split(#将张量按维度dim进行切分,返回值为张量列表
tensor,
solit_size_or_sections,@为int时,表示每一份的长度,为list时按list元素切分
dim=0
)
三、张量索引
torch.index_select(#维度dim上,按index索引数据,返回值为依索引数据拼接的张量
input,#要索引的张量
dim,#要索引的维度
index,#要索引数据的序号
out=None
)
torch.masked_select(#按mask中的True进行索引
input,#要索引的张量
mask,#与input同形状的布尔类型张量