import torch
A_idx = torch.LongTensor([0, 2]) # the index vector
B = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
C = B.index_select(1, A_idx)
# 1 3
# 4 6
pytorch的tf.slice
最新推荐文章于 2023-03-02 11:06:40 发布
本文详细介绍了如何在PyTorch中使用类似TensorFlow的slice功能来选取和操作张量的部分区域。通过实例代码,展示了如何利用torch.index_select、torch.slice等方法实现等效的切片操作,帮助读者从TensorFlow平滑过渡到PyTorch。
摘要由CSDN通过智能技术生成