Pytroch中常用函数

1、torch.empty()

torch.empty(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) → Tensor
返回一个size大小的tensor,里面的数值是随机的,主要用到size,其它都是可选内容。
szie可以是列表,元组等。

2、torch.rand()

torch.rand(*size,out=None)
返回一个大小为size的张量,内部结果随机。

3、torch.zeros()

torch.zeros(*size,dtype = None)
返回一个形状为为size,类型为dtype,里面的每一个值都是0的tensor.

4、torch.tensor()

torch.tensor(x)
返回一个tensor,里边的值和x的值对应相等,x可以是一个列表,元组等。

5、x.new_ones()

y = x.new_ones(a,b,dtype = None)
其中x是任意一个tensor,返回的y是一个a*b的数据类型为dtype的值为1的tensor,x保持不变。
等价的函数是y = torch.ones(*size)

6、torch.rand_like()

x4 = torch.rand_like(x3,dtype = None)
返回一个维度和x3一样的tensor,其中里边的值是随机的。

7、tensor的加法

tensor的加法就是逐个元素相加。
x和y是两个tensor,以下结果相同
z = x+y;z = torch.add(x,y)
y.add_(x);等价于 y = x + y;

8、x.view

y = x.view(a,b)
返回一个a*b的tensor,其中x也是一个tensor,且含有a*b个元素值。
常用:y = x.view(-1) 返回一个一维的tensor,即只有一行。
x个y共享内存,改变x或者y的值同时也会改变y或者x的值。
如果想让x和y断开联系的话,可以先创造一个x的副本,然后在进行view。
即:y = x.clone().view(a,b)

9、x.item()

可以将只包含一个元素的tensor转化对应的标量。

10、tensor转numpy

b = x.numpy()
b是一个numpy数组,x是一个tensor,x和b共享内存。

11、numpy转tensor。

a是一个numpy数组,b = torch.form_numpy(a),可以把数组a转换为对应的tensor,同样也会共享内存。

12、检查cuda是否可用

torch.cuda.is_available()
device = torch.device(c)  #c可以是'cpu'也可以是"cuda"
x = x.to(device),就是把x送device并返回x。

13、torch.set_default_tensor_type()

torch.set_default_tensor_type()用来设计tensor的默认数据类型
如果tensor的数据类型不一致可能导致无法训练
例:统一设置为double  torch.set_default_tensor_type(torch.DoubleTensor)

14、数据打包

import torch.utils.data as Data
dataset = Data.TensorDataset(arr_x,labels) 
其中arr_x和labels都是tensor.
通过第一个维度对两个tensor同时索引,要求arr_x和labels的第一个维度必须一致。
如:
for x in dataset:
    print(x)
    break
输出:
(tensor([ 0.1078, -0.9351]), tensor(-6.8768))
其中第一个tensor是arr_x中的数据,第二个tensor是labels中的数据。

data_iter = Data.DataLoader(dataset,batch_size,shuffle = True) 
用于分批加载数据,dataset就是用TensorDataset打包好的数据包。每次加载batch_size个数据

15、tensor.max()

_,prediction = x.max(0/1)
如果为0的话,返回列最大值和该最大值对应的列索引下标。
如果为1的话,返回行最大值和该最大值对应的行索引下标。

16、ToTensor

归一化,常常用于对图像的预处理操作。
除此之外还有其他的预处理操作
transform.Normalize() 使图像数值归一化。
transform.CenteCrop() 可以用于图像的裁剪
...

17、torch.cat()

可用于拼接tensor,
poly_tensor = torch.cat(x,y,z)
其中x,y,z的shape[0]要相同。返回一个新的tensor (x,y,z)

18、tensor.permute()

维度转换
x.shape = (a,b,c)
y = x.permute(1,2,0)#对应(b,c,a)
该函数最常用的地方就是图像通道,即cv2和PIL等读入的图像通道都是放在后边,
这时候需要调到前边去,就可以使用该函数啦~

  • 4
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值