代码如下:
import torch if __name__ == '__main__': #一、tensor的索引 a=torch.Tensor([[1,2,3],[0,3,2]]) #1.根据下标索引 print(a[1][2]) print("1.根据下标索引",a[:,1]) #2.选择a中大于1的元素,返回和a相同大小的tensor,符合条件的输出1,否则输出0 s=a>1 print("2.选择a中大于1的元素",s) #3.选择符合条件的元素并返回 print("3.选择符合条件的元素并返回",a[s]) #4.torch.where(condition,x,y),满足condition的位置输出x,否则输出y hh=torch.where(a>1,2,a) print("4.torch.where(condition,x,y)",hh) #5.clamp()函数 t=a.clamp(1,2)#限制最小值为1,最大值为2 print("5.clamp()函数",t) #6.选择非0元素的坐标 g=torch.nonzero(a) print("6.选择非0元素的坐标",g) #二、tensor的变形 #常见的有 #view,resize,reshape #transpose,permute #squeeze,unsqueeze #expand,exoand_as #1.view,resize,reshape a=torch.arange(1,17) print(a.shape) print(a.reshape(2,8).shape)#reshape print(a.resize(4,4).shape)#resize print(a.view(8,2).shape)#view #2.transpose,permute:各维度之间的位置变换 a=torch.Tensor([[1,2,3],[4,5,6]]) b=a.transpose(0,1)#将第0维与第1维的元素进行转置 print(b) c=a.permute(1,0)#按照1,0的维度顺序重新排序 print(c) #3.squeeze,unsqueeze,用来去除size为1的维度 a=torch.arange(4) print(a.shape,a) b=a.unsqueeze(1) print(b.shape,b) #4.expand,exoand_as:采样复制的方式来扩展tensor的维度 a=torch.randn(2,2,1) b=a.expand(2,2,3) print("a",a) print("b",b)