torch.flatten(t, start_dim=0, end_dim=-1) 的实现原理如下。假设类型为 torch.tensor 的张量 t 的形状如下所示:(2,4,3,5,6)
,则 orch.flatten(t, 1, 3).shape
的结果为 (2, 60, 6)
。将索引为 start_dim
和 end_dim
之间(包括该位置)的数量相乘,其余位置不变。因为默认 start_dim=0,end_dim=-1,所以 torch.flatten(t)
返回只有一维的数据。
torch flatten 的理解
最新推荐文章于 2024-09-12 19:13:56 发布