基础知识补充
幸运的悦子
这个作者很懒,什么都没留下…
展开
-
torch.stack()
torch.stack() 是 PyTorch 中的一个函数,用于将多个张量沿着新的维度进行堆叠。需要注意的是,torch.stack() 函数要求所有被堆叠的张量的形状必须相同。此外,堆叠后的张量的维度数目将比原先多一维。将 seq 序列中的每个张量视为一行,构建一个新的二维张量(即矩阵),其中第。其中,第一行表示 a 张量中的元素,第二行表示 b 张量中的元素。在指定的 dim 维度上对这些行进行堆叠,从而得到一个新的张量。原创 2023-04-25 21:10:20 · 505 阅读 · 0 评论 -
torch.gather()解释与使用
具体而言,torch.gather(input, dim, index) 的作用是根据 dim 维度上的 index 索引值,从 input 张量中提取对应位置的数据,并组合成一个新的张量返回。其中,output 张量的第一行表示从 input 张量的第 0 行中提取了第 0 列和第 2 列的数据(即 [1, 3]),第二行表示从第 1 行中提取了第 1 列和第 3 列的数据(即 [6, 8]),第三行表示从第 2 行中提取了第 0 列和第 1 列的数据(即 [9, 10])。例如,假设有一个形状为。原创 2023-04-25 20:46:35 · 376 阅读 · 0 评论