pytorch基础

  1. 张量的基本信息

tensor = torch.randn(2,3,4)

print(tensor.type()) # 数据类型 torch.FloatTensor,是一个浮点型的张量

print(tensor.size()) # 张量的shape,是个元组 torch.Size([2, 3, 4])

print(tensor.dim()) # 维度的数量 3

  1. 张量的命名

程序中,一个好的命名可以便于其他人读懂代码,张量的命名也是如此。这样可以方便地使用维度的名字来做索引或其他操作,提高了可读性、易用性,防止程序出错,便于其他人阅读和修改。

在PyTorch 1.3之前,需要使用注释

Tensor[N, C, H, W]

images = torch.randn(32, 3, 56, 56)

images.sum(dim=1)

images.select(dim=1, index=0)

PyTorch 1.3之后

NCHW = [‘N’, ‘C’, ‘H’, ‘W’]

images = torch.randn(32, 3, 56, 56, names=NCHW)

images.sum(‘C’)

images.select(‘C’, index=0)

也可以这么设置

tensor = torch.rand(3,4,1,2,names=(‘C’, ‘N’, ‘H’, ‘W’))

使用align_to可以对维度方便地排序

tensor = tensor.align_to(‘N’, ‘C’, ‘H’, ‘W’)

4.张量数据类型转换

在Pytorch中,FloatTensor处理速度远远快于DoubleTensor,因此默认采用FloatTensor,也可以通过转换变成其他类型的数据。

设置默认类型

torch.set_default_tensor_type(torch.FloatTensor)

类型转换

tensor = tensor.cuda()

tensor = tensor.cpu()

tensor = tensor.float()

tensor = tensor.long()

4.1 torch.Tensor与np.ndarray转换

除了CharTensor类型外,其他所有CPU上的张量都支持转换为numpy格式,当然也可以再转换回来。

ndarray = tensor.cpu().numpy()

tensor = torch.from_numpy(ndarray).float()

tensor = torch.from_numpy(ndarray.copy()).float() # If ndarray has negative stride.

4.2 torch.tensor与PIL.Image转换

在Pytorch中,张量默认采用[N, C, H, W]的顺序,并且数据范围在[0,1],有时候处理数据时需要进行转置和规范化。

torch.Tensor -> PIL.Image

image = PIL.Image.fromarray(torch.clamp(tensor*255, min=0, max=255).byte().permute(1,2,0).cpu().numpy())

image = torchvision.transforms.functional.to_pil_image(tensor)

PIL.Image -> torch.Tensor

path = r’./figure.jpg’

tensor = torch.from_numpy(np.asarray(PIL.Image.open(path))).permute(2,0,1).float() / 255

tensor = torchvision.transforms.functional.to_tensor(PIL.Image.open(path)) # Equivalently way

5.张量的常用操作

5.1 矩阵乘法

Matrix multiplcation: (mn) * (np) * -> (m*p).

result = torch.mm(tensor1, tensor2)

Batch matrix multiplication: (bmn) * (bnp) -> (bmp)

result = torch.bmm(tensor1, tensor2)

Element-wise multiplication.

result = tensor1 * tensor2

5.2 计算两组数据之间的两两欧式距离

dist = torch.sqrt(torch.sum((X1[:,None,:] - X2) ** 2, dim=2))

5.3 张量形变

将卷积层输入全连接层的情况时,通常需要对张量做形变处理如.view()和.reshape()等,但是相比torch.view,torch.reshape可以自动处理输入张量不连续的情况。

tensor = torch.rand(2,3,4)

shape = (6, 4)

tensor = torch.reshape(tensor, shape)

5.4 打乱顺序

tensor = tensor[torch.randperm(tensor.size(0))] # 打乱第一个维度

5.5 水平翻转

Pytorch不支持tensor[::-1]这样的负步长操作,水平翻转可以通过张量索引实现。

假设张量的维度为[N, D, H, W].

tensor = tensor[:, :, :, torch.arange(tensor.size(3) - 1, -1, -1).long()]

5.6 张量复制

tensor.clone()

tensor.detach()

tensor.detach.clone()

5.7 张量拼接

torch.cat和torch.stack的区别在于torch.cat沿着给定的维度拼接,而torch.stack会新增一维。当参数是3个10x5的张量,torch.cat的结果是30x5的张量,而torch.stack的结果是3x10x5的张量。

tensor = torch.cat(list_of_tensors, dim=0)

tensor = torch.stack(list_of_tensors, dim=0)

5.8 将整数标签转为one-hot编码

Pytorch的标记默认从0开始,转换为one-hot编码在数据处理时也经常用到。

tensor = torch.tensor([0, 2, 1, 3])

N = tensor.size(0)

num_classes = 4

one_hot = torch.zeros(N, num_classes).long()

one_hot.scatter_(dim=1, index=torch.unsqueeze(tensor, dim=1), src=torch.ones(N, num_classes).long())

5.9 得到非零元素

torch.nonzero(tensor) # index of non-zero elements,索引非零元素

torch.nonzero(tensor==0) # index of zero elements,索引零元素

torch.nonzero(tensor).size(0) # number of non-zero elements,非零元素的个数

torch.nonzero(tensor == 0).size(0) # number of zero elements,零元素的个数

5.10 张量扩展

将64512的张量扩展为6451277的张量

tensor = torch.rand(64,512)

torch.reshape(tensor, (64, 512, 1, 1)).expand(64, 512, 7, 7)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值