文章目录
1. pytorch中 max()、view()、 squeeze()、 unsqueeze() 的区别
链接
查询时间:2019年3月15日10:17:46
- max() 求行或者列的最大值,并返回最值和索引。
- view() 改变数据的形式 长宽。
- squeeze() 压缩矩阵维度为1的这一维度。
- unsqueeze() 增加一个维度1。
2. torchvision.transforms.ToPILImage / ToTensor 通道的变化
原文:Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8
Docs 链接
查询时间:2019年3月18日10:39:09
理解:同样是存储图片数据,tensor和numpy的存储方式却不同,tensor的channel在第一位,numpy的channel却在最后一位。两者的存储方式不同,因此读取方式也不同;但各自的读方式对应各自的存储方式。如果希望最后是tensor数据,那么在传入ToPILImage ->ToTensor之前的数据先转化成tensor的最方便,即输入什么经过ToPILImage ->ToTensor之后依然会是相同的格式(C*W*H的顺序就不会改变) 。
使用:这两个方法有时会一起使用,即先将数据(numpy)转换成图片(transforms.ToPILImage),数据增强(RandomRotation,RandomShift……),之后再转换成数据(transforms.ToTensor),标准化(transforms.Normalize)后进入网络。
3. .repeat(几行,几列)、.view()
- .repeat() 将原来的数据按行按列,复制多少次。数据的大小改变。
- .view() 将原来的数据重新排列(改变shape),数据的大小不改变。
g = 5
b = torch.arange(g).repeat(g, 1).view([1, 1, g, g])
c = torch.arange(g).repeat(g, 1).t().view([1, 1, g, g])
print(b)
print(c)
tensor([[[[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]]]])
tensor([[[[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[3, 3, 3, 3, 3],
[4, 4, 4, 4, 4]]]])
4. .max(1, keepdim=True)
- 查询时间:2019年5月15日22:44:34
- 1 表示在“行方向”求取最大值,keepdim=True 表示要返回,最大值的索引。
import numpy as np
import torch
targets = np.array([[1, 2, 3, 4],
[0, 1, 0, 0],
[1, 2, 3, 4]])
print(targets.shape)
a, b = torch.from_numpy(targets[:, 1:]).max(1, keepdim=True)
print(a.data)
print(b)
输出结果如下:
(3, 4)
tensor([[4],
[1],
[4]], dtype=torch.int32)
tensor([[2],
[0],
[2]])