![](https://img-blog.csdnimg.cn/6470d5bc3ee046228fc979ae93bf5d85.png)
PyTorch | 拼接与拆分
一、拼接
1. cat 函数
t
o
r
c
h
.
c
a
t
(
i
n
p
u
t
s
,
d
i
m
=
0
)
:
\qquad torch.cat(inputs, dim=0):
torch.cat(inputs,dim=0):在给定维度
d
i
m
dim
dim(默认为0) 上对输入的张量
i
n
p
u
t
s
inputs
inputs 进行拼接操作,在给定
d
i
m
dim
dim 维度以外的维度上形状要保持相同。
2. stack 函数
t
o
r
c
h
.
s
t
a
c
k
(
s
e
q
u
e
n
c
e
,
d
i
m
=
0
)
:
\qquad torch.stack(sequence, dim=0):
torch.stack(sequence,dim=0):沿着一个新维度
d
i
m
dim
dim 对输入张量进行连接,属于扩张再拼接的函数,也就是在新维度上进行堆叠,序列中所有张量都应该为相同的形状。
3. cat & stack 区别
- s t a c k stack stack 函数会创建一个新的维度。
- s t a c k stack stack 函数要求拼接的两个 t e n s o r tensor tensor 形状完全相同。
二、拆分
1. chunk 函数
t
o
r
c
h
.
c
h
u
n
k
(
t
e
n
s
o
r
,
c
h
u
n
k
s
,
d
i
m
=
0
)
:
\qquad torch.chunk(tensor, chunks, dim=0):
torch.chunk(tensor,chunks,dim=0):将
t
e
n
s
o
r
tensor
tensor 在指定维度
d
i
m
dim
dim 上进行分块(个数
c
h
u
n
k
s
chunks
chunks)。如果沿指定维的张量形状大小不能被
c
h
u
n
k
s
chunks
chunks 整分,则最后一个分块会小于其它分块。
2. split 函数
t o r c h . s p l i t ( t e n s o r , s p l i t _ s i z e , d i m = 0 ) : \qquad torch.split(tensor, split\_size, dim=0): torch.split(tensor,split_size,dim=0):将输入张量分割成相等形状的块结构。如果沿指定维的张量形状大小不能被 s p l i t _ s i z e split\_size split_size 整分,则最后一个分块会小于其它分块。
-
s p l i t _ s i z e split\_size split_size 分割方案:
- 第一种分割方案是指定单个分块的形状大小。
- 第二种分割方案是指定 l i s t list list,待分割张量 t e n s o r tensor tensor 将会被分割为 l e n ( l i s t ) len(list) len(list) 份,每一份的大小取决于 l i s t list list 中的元素, l i s t list list 中的数字总和应等于该维度的大小,否则会报错。
3. chunk & split 区别
s
p
l
i
t
\qquad split
split 函数是根据块内数据长度
来拆分,
c
h
u
n
k
chunk
chunk 函数是根据块个数
来拆分。
4. unbind 函数
torch.unbind(input, dim=0)
:移除
i
n
p
u
t
input
input 张量中的指定维度
d
i
m
dim
dim。