关于Tensor
tensor是Pytorch框架中的核心部分,使用者设计的一系列操作流程如数据预处理,亦或者是模型计算等都可以看做是对tensor进行的一系列操作。
因此,熟练掌握对tensor的基本操作是搭建网络解决实际问题的基础。以下将介绍对tensor的基本操作,以便于初学者快速掌握。
tensor 的创建
pytorch提供4种不同的方式来创建tensor:
- torch.Tensor(inputs)
- 以该种方式创建的tensor占有单独的存储空间,数据类型为pytorch的默认数据类型torch.float32.
- torch.tensor(inputs)
- 以该种方式创建的tensor占有单独的存储空间,数据类型根据inputs而定.
- torch.as_tensor(inputs)
- 以该种方式创建的tensor与inputs共享存储空间,即当inputs的内容修改时,该tensor也会做相应的修改,数据类型根据inputs而定.
- torch.from_numpy(inputs)
- 以该种方式创建的tensor与inputs共享存储空间,即当inputs的内容修改时,该tensor也会做相应的修改,数据类型根据inputs而定,与as_tensor不同,from_numpy只能够接受numpy arr作为输入.
因此推荐使用torch.tensor()或者torch.as_tensor()作为tensor的创建方式。
tensor的基本属性
查看tensor的形状
- tensor.shape
- tensor.size()
重塑tensor的形状
- tensor.reshape()
- tensor.view()
展平tensor
- tensor.flatten(start_dim=) 从start_dim开始之后的维度都会被展平
开辟或删除维度
- tensor.unsqueeze(dim=)从第dim维上开拓一个新的维度
- tensor.squeeze()将所有长度为1的维度删除
将两个不同形状的tensor广播成相同形状
- torch.broadcast_tensor(a,b)
求极值
- tensor.max(dim=) 求最大值
- tensor.min(dim=) 求最小值
- tensor.argmax(dim=) 求最大值对应元素的index
- tensor.argmin(dim=) 求最小值对应元素的index
reduction 操作
- tensor.mean(dim=) 求均值
- tensor.sum(dim=) 求和
- tensor.prod() 求乘积
- tensor.numel() 求元素总数
比较 操作
- torch.le(tensor, value) 小于等于
- torch.lt(tensor, value) 小于
- torch.ge(tensor, value) 大于等于
- torch.gt(tensor, value) 大于
- torch.eq(tensor, value) 等于
- torch.ne(tensor, value) 不等于
拼接操作
- torch.cat((tensor1, tensor2), dim=) 要求tensor1与tensor2的dim存在, 在第dim维上进行拼接
- torch.stack((tensor1,…,tensorn), dim=) dim为新拓展的维度,然后在第dim维上进行拼接. stack = unsqueeze + cat
tensor的其他属性
- tensor.dtype 查看tensor的数据类型
- tensor.device 查看tensor在哪个设备上
- tensor.layout