[Pytorch基础操作] List、Numpy、Tensor操作基础

计算机视觉中一切皆矩阵:

  • 图像:3维数据(彩色)、2维数据(黑白)
  • 视频:4维数据(T,3,H,W)

要想代码写的顺手, l i s t 、 n u m p y 、 t e n s o r 的操作一定要烂熟于心 要想代码写的顺手,list、numpy、tensor的操作一定要烂熟于心 要想代码写的顺手,listnumpytensor的操作一定要烂熟于心

一、list列表

1.1 list创建

list 是Python中最基本的数据结构。序列中的每个元素都分配一个数字(它的位置index),与字符串的索引一样,列表索引从0开始。列表可以进行索引,切片,加,乘,检查成员,截取、组合等。在[]内用逗号分隔开任意类型的值,可以实现索引存取。

  • 直接创建
l = [1,2,3,4,5]
l[0] = 10
print(l[0])

10

  • 列表生成式
l = [e for e in range(10)]
print(l)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  • 随机生成random.random()用于生成一个0到1之间的随机浮点数
l = [random.random() for i in range(10)]
print(l)

[0.9941637112357136, 0.7179263215012058, 0.41440567906139236, 0.5241162283103614, 0.6378630095667321, 0.11731480610659373, 0.9410968887798861, 0.8591094876738674, 0.9812292614236708, 0.1962255254167019]

1.1 list内置方法

  • 切片[start (开始),stop (停止),step(步长)]:列表切片的方向取决于起始索引、结束索引以及步长,当起始索引在结束索引右边是就是从右往左取值,同理反之。当步长为负数时从start开始索引至stop,起点必须大于终点。
a[:]           # a copy of the whole array
a[start:]      # items start through the rest of the array
a[:stop]       # items from the beginning through stop-1
a[start:stop]  # items start through stop-1
a[start:stop:step] # start through not past stop, by step

a[-1]    # last item in the array
a[-2:]   # last two items in the array
a[:-2]   # every items except the last two items

a[::-1]    # all items in the array, reversed
a[1::-1]   # the first two items, reversed
a[:-3:-1]  # the last two items, reversed
a[-3::-1]  # every items except the last two items, reversed
  • 索引存取[idx]正向取值+反向取值,即可存也可以取list[idx]
L = ['Google', 'Runoob', 'Taobao']
print(L[0], L[-1])  # 读取列表第一个和倒数第一个元素

Google Taobao

  • 拼接++号 用于拼接列表。
l = [1,2,3] + [4,5,6]
print(l)

[1, 2, 3, 4, 5, 6]

  • 重复* :*号 用于重复列表。
l2 = [1]*5
print(l2)

[1, 1, 1, 1, 1]

  • 长度len():列表元素个数len(list)
l = [1,2,3,4,5,6]
print(len(l))
  • 成员运算in和not in:判断某元素十否在list中, e in list
l = [1,2,3,4,5,6]
print(3 in l)
print(6 not in l)

True
False

  • 按下标删除del:按照下标idx 使用 del 语句来删除列表的元素,del list[idx]
l = [1,2,3,4,5,6]
del l[0]
print(l)

[2, 3, 4, 5, 6]

  • 插入insert():对任意位置idx插入元素list.insert(idx,e)
l = [1,2,4]
l.insert(2,3)
print(l)

[1, 2, 3, 4]

  • 追加值append():在列表末尾添加新的对象,list.append(e)
l = [1,2,3,4,5,6]
l.append(7)
print(l)

[1, 2, 3, 4, 5, 6, 7]

  • 弹栈pop()list.pop()默认删除最后一个元素
l = [1,2,3,4]
l.pop()
print(l)

[1, 2, 3]

  • 按值删除remove():从左向右顺序遍历,删除第一个找到的对应值的元素list.remove(e)
l = [1,2,3,3,3]
l.remove(3)
print(l)

[1, 2, 3, 3]

  • 清空clear():清空list内所有元素,list.clear()
l = [1,2,3,3,3]
l.clear()
print(l)

[]

  • 反转reverse():反转列表list.reverse()
l = [1,2,3,4,5]
l.reverse()
print(l)

[5, 4, 3, 2, 1]

  • 排序sort():可以降序或升序排序,list.sort(reverse=Ture/False)
l = [1,2,3,4,5]
l.sort(reverse=True)
print(l)

[5, 4, 3, 2, 1]

  • 查找index():查找对应元素的下标(顺序查找第一个),list.index(e)
l = [1,2,3,3,3]
print(l.index(3))

2

  • 统计个数count():统计对应元素出现次数list.count(e)
l = [1,2,3,3,3]
print(l.count(3))

3

二、numpy矩阵

2.1 创建numpy.ndarray

NumPy 最重要的一个特点是其 N 维数组对象 ndarray,它是一系列同类型数据的集合,以 0 下标为开始进行集合中元素的索引。nd array表示n个维度的数字。

  • 用list创建numpy.array(list)object 是list对象,dtype(可选)数组元素的数据类,copy(可选)对象是否需要复制,order(默认A)创建数组的样式,C为行方向,F为列方向,A为任意方向,subok 默认返回一个与基类类型一致的数组,ndmin 指定生成数组的最小维度
numpy.array(object, dtype = None, copy = True, order = None, subok = False, ndmin = 0)
n = np.array([[1,2,3],[4,5,6]])
print(n)

[[1 2 3]
[4 5 6]]

  • 使用np函数创建:shape可以是元组或列表,dtype表示元素类型(默认numpy.float64)。
np.random.random(size=(3,4))  # 生成size维度的多维矩阵,元素是0-1间的随机数
np.random.random(d1,d2...dn)  # 生成d1xd2xd3..xdn维度的矩阵,元素是0-1间的随机数,如random(34)生成34列数组

np.random.randn(d1,d2...dn)  # 生成d1xd2xd3..xdn维度的矩阵,元素服从标准正态分布~N(0,1),如randn(34)生成34列数组
np.random.normal(loc,scale,size=(3,4))  # 生成size形状的矩阵,元素服从整体分布~N(loc均值,scale标准差)

np.random.randint(low=0,high=255,size=[3,64,64])  # 生成形状size的多维数组,整数元素在[low,high]之间随机取值,size默认是1个整数

np.arange(start,end,stop)  # 创建[start,end)步长step的序列,start默认为0,step默认为1,和list的range相同
np.linspace(start,end,num)  # 从start到end共1xnum个元素的等差数列

np.ones(shape, dtype)  # 全1矩阵
np.zeros(shape, dtype)  # 全0矩阵
np.full(shape, full_value, dtype)  # 全full_value矩阵
np.eyes(N, dtype)  # N阶单位方阵

2.2 ndarray常见属性

  • ndim:维度
  • shape:形状
  • size:元素个数
  • dtype:元素类型
img = plt.imread('1.jpg')
print(img.ndim)
print(img.shape)
print(img.size)
print(img.dtype)

3
(1630, 1080, 3)
5281200
uint8

2.2 ndarray内置方法

  • 索引操作[idx]:与列表list完全相同,
l = np.random.randint(1,10,[3,5])
print(l[0])

[2 3 3 4 6]

  • 切片[start (开始),stop (停止),step(步长)]:同list切片操作,当步长为负数时从start开始索引至stop,起点必须大于终点。
a[:]           # a copy of the whole array
a[start:]      # items start through the rest of the array
a[:stop]       # items from the beginning through stop-1
a[start:stop]  # items start through stop-1
a[start:stop:step] # start through not past stop, by step

a[-1]    # last item in the array
a[-2:]   # last two items in the array
a[:-2]   # every items except the last two items

a[::-1]    # all items in the array, reversed
a[1::-1]   # the first two items, reversed
a[:-3:-1]  # the last two items, reversed
a[-3::-1]  # every items except the last two items, reversed
l = np.random.randint(1,10,[3,4])
print(l[1:8:2,:])

[[5 2 3 8]]

例如对原图像进行行反转(上下反转)、列反转(左右反转)、颜色反转(RGB-BGR)
在这里插入图片描述

# 读取图像
img = cv2.imread('test.jpg')
print(img.shape)  # (64, 64, 3)
img = img[::-1,::-1,::-1]
# 显示图像
cv2.imshow('Image', img)
cv2.waitKey(0)
cv2.destroyAllWindows()

在这里插入图片描述

  • 变形reshape()n.reshape(new_shape),对np数组直接用reshape变形,新形状new_shape可以是列表或元组,要保证old_shape所有维度乘积==new_shape所有维度乘积,-1表示自动计算剩下的维度。
img = cv2.imread('test.jpg')
print(img.shape)  # (64, 64, 3)
img = np.reshape(img, (4,4,-1)) 
print(img.shape)  # (4, 4, 768)
  • 拼接concatenate():拼接的数组维度必须相同,可以通过axis参数改变拼接方向np.concatenate((n1,n1), axis=dim_idx),axis表示合并第几个维度
n1 = np.random.randint(1,100,[3,4])  # 3行4列
n2 = np.random.randint(1,100,[3,4])
print(np.concatenate((n1,n1),axis=0).shape) # 上下合并(行合并)合并第一个维度axis=0
print(np.concatenate((n1,n1),axis=1).shape) # 左右合并(列合并)合并第二个维度axis=1

print(np.hstack((n1,n2)).shape)  # 垂直合并
print(np.vstack((n1,n2)).shape)  # 水平合并

(6, 4)
(3, 8)
(3, 8)
(6, 4)

  • 拆分split()np.split(n, (1,2,1), axis=1)指定axis=1第二维度拆分,指定拆分比例 (1,2,1)
n = np.array([[71,44,92,63],[1,19,70,23],[19,73,24,79]])
print(np.vsplit(n, 3))  # 垂直拆分,均匀按行拆分3份
print(np.hsplit(n, 4))  # 水平拆分,均匀按列拆分4份
print(np.split(n,(1,2,1),axis=1))  # 指定axis维度拆分,指定拆分比例

[array([[71, 44, 92, 63]]), array([[ 1, 19, 70, 23]]), array([[19, 73, 24, 79]])]
[array([[71],
[ 1],
[19]]), array([[44],
[19],
[73]]), array([[92],
[70],
[24]]), array([[63],
[23],
[79]])]
[array([[71],
[ 1],
[19]]), array([[44],
[19],
[73]]), array([], shape=(3, 0), dtype=int32), array([[44, 92, 63],
[19, 70, 23],
[73, 24, 79]])]

例如,拆分出RGB图像的某一通道灰度图

# 读取图像
img = cv2.imread("test.jpg")
img, _, _ = np.split(img,3,axis=2)
# 显示图像
img = cv2.convertScaleAbs(img)
cv2.imshow('Image', img)
cv2.waitKey(0)
cv2.destroyAllWindows()

在这里插入图片描述

  • 赋值= 与 拷贝copy():赋值:用的是同一个内存,深拷贝:新建副本。
n1 = np.array([1,2,3,4])
n2 = n1  # 赋值:用的是同一个内存
n3 = np.copy(n1)  # 深拷贝:新建副本
n1[0] = 0
print(n2)
print(n3)

[0, 2, 3, 4]
[1 2 3 4]

  • 统计操作:注意这里的axis不是对该维度求统计值,而是对其他维度求多个该维度的统计值。如axis=0表示每一列的多行求中位数
n = np.array([[1,2,3],[4,5,6],[7,8,9]])
print(np.sum(n, axis=0))  # axis=0表示每一列的多行求和
print(np.min(n,axis=0))  # axis=0表示每一列的多行求最小值
print(np.max(n,axis=0))  # axis=0表示每一列的多行求最大值
print(np.mean(n,axis=0))  # axis=0表示每一列的多行求平均值
print(np.average(n,axis=0))  # axis=0表示每一列的多行求平均值
print(np.median(n,axis=0))  # axis=0表示每一列的多行求中位数
print(np.argmin(n,axis=0))  # axis=0表示每一列的多行求最小值下标
print(np.argmax(n,axis=0))  # axis=0表示每一列的多行求最小值下标
print(np.std(n,axis=0))  # axis=0表示每一列的多行求标准差
print(np.var(n,axis=0))  # axis=0表示每一列的多行求方差

[12 15 18]
[1 2 3]
[7 8 9]
[4. 5. 6.]
[4. 5. 6.]
[4. 5. 6.]
[0 0 0]
[2 2 2]
[2.44948974 2.44948974 2.44948974]
[6. 6. 6.]

  • 矩阵操作:对矩阵所有元素操作、矩阵间加/减/乘法(乘法要求mxn,nxk)、矩阵转置/逆
n1 = np.array([[1,2,3],[4,5,6],[7,8,9]])
print(n1+10) # 加
print(n1-10) # 减
print(n1*10) # 乘
print(n1/10) # 除
print(n1//2) # 整除
print(n1**2) # 平方
print(n1%2) # 取余

[[11 12 13]
[14 15 16]
[17 18 19]]
[[-9 -8 -7]
[-6 -5 -4]
[-3 -2 -1]]
[[10 20 30]
[40 50 60]
[70 80 90]]
[[0.1 0.2 0.3]
[0.4 0.5 0.6]
[0.7 0.8 0.9]]
[[0 1 1]
[2 2 3]
[3 4 4]]
[[ 1 4 9]
[16 25 36]
[49 64 81]]
[[1 0 1]
[0 1 0]
[1 0 1]]

n1 = np.array([[1,2,3],[4,5,6],[7,8,9]])
n2 = np.array([[1,2,3],[4,5,6],[7,8,9]])
print(n1+n2)  # 矩阵加法
print(n1-n2)  # 矩阵减法
print(n1.dot(n2))  # 矩阵乘法
print(n1 @ n2)  # 矩阵乘法
print(n1.T)  # 矩阵转置

[[ 2 4 6]
[ 8 10 12]
[14 16 18]]
[[0 0 0]
[0 0 0]
[0 0 0]]
[[ 30 36 42]
[ 66 81 96]
[102 126 150]]
[[ 30 36 42]
[ 66 81 96]
[102 126 150]]

n1 = np.array([[1,2],[-1,-3]])
print(n1.T)  # 矩阵转置
print(np.linalg.inv(n1))  # 矩阵逆
print(np.linalg.det(n1))  # 矩阵行列式

[[ 1 -1]
[ 2 -3]]
[[ 3. 2.]
[-1. -1.]]
-1.0

  • 广播机制:为不同维度的矩阵尽量提供运算的可能性(补充缺失的维度用已有值填充缺失元素
 n1 = np.array([[1,1,1],[2,2,2]])
n2 = np.array([3,3,3])  # 第二列补全为[3,3,3]
print(n1+n2)

[[4 4 4]
[5 5 5]]

  • 数学操作
    在这里插入图片描述

  • 保存np.save("name", n)保存ndarry为.npy文件

np1 = np.array([[1,1,1],[2,2,2]])
np.save('np1',np1)

三、tensor张量

3.1 创建tensor

  • 用list创建
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32, device=DEVICE, requires_grad=True)
print(x)

tensor([[1., 2., 3.],
[4., 5., 6.]], device=‘cuda:0’, requires_grad=True)

  • 用numpy创建:tensor与numpy转换
# 从numpy创建tensor
x = torch.from_numpy(np.array([1, 2]))
print(x)
# 将tensor转换为numpy
x = x.numpy()
print(x)

tensor([1, 2], dtype=torch.int32)
[1 2]

  • 正态/均匀分布创建
# 标准正态分布 形状为[2, 3]
print(torch.randn(2, 3))

# 0-1均匀分布 形状为[2, 3]
x = torch.rand(2, 3)
print(x)

# low-high的整数的均匀分布 形状为[2, 3]
print(torch.randint(low=0, high=1, [2, 3]))

# 0-1均匀分布 和x形状相同的矩阵
y = torch.rand_like(x)
print(y)

tensor([[-0.8816, 0.6829, 0.2136],
[ 0.3556, -0.7659, -0.5187]])
tensor([[0.7780, 0.6701, 0.7642],
[0.2193, 0.4377, 0.5646]])
tensor([[5, 9, 2],
[9, 5, 4]])

  • 填充创建
#全1填充
print(torch.ones(2, 3))
#全0填充
print(torch.zeros(2, 3))
#full填充
print(torch.full([2, 3], 7))
#对角矩阵
print(torch.eye(3, 3))
#创建但不初始化 0
print(torch.empty(2, 3))

tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[0., 0., 0.],
[0., 0., 0.]])
tensor([[7, 7, 7],
[7, 7, 7]])
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
tensor([[0.0000, 1.8750, 0.0000],
[1.8750, 0.0000, 1.8750]])

  • 递增/减数列
#自增数列,0到9,步长是2
print(torch.arange(start=0, end=10, step=2))
#等差数列,0到10,4个数
print(torch.linspace(start=0, end=10, steps=4))

tensor([0, 2, 4, 6, 8])
tensor([ 0.0000, 3.3333, 6.6667, 10.0000])

  • 指定类型创建
#指定类型的创建
print(torch.FloatTensor([1, 2]))
print(torch.LongTensor([1, 2]))
#获取数据类型
print(torch.randn(2, 3).dtype)
#转换数据类型
print(torch.randn(2, 3).to(torch.float64).dtype)

tensor([1., 2.])
tensor([1, 2])
torch.float32
torch.float64

3.2 tensor内置方法

tensor索引

#创建测试数据
#4张图片,3通道,28*28像素
a = torch.randn(4, 3, 28, 28)
a.shape

#直接索引
#查看第0张图片
print(a[0].shape)
#查看第0张图片的第0个通道
print(a[0, 0].shape)
#查看第0张图片的第0个通道的第2行
print(a[0, 0, 2].shape)
#查看第0张图片的第0个通道的第2行第4列
print(a[0, 0, 2, 4].shape)

#切片
#查看0-1张图片
print(a[:2].shape)
#查看0-1张图片的0-1通道
print(a[:2, :2].shape)
#查看0-1张图片的所有通道的倒数5-正数24行
print(a[:2, :, -5:25].shape)
#有间隔的索引
print(a[:, :, :, ::2].shape)
#用...表示多个被省略的:
#取所有图片的第0-2列
print(a[..., :2].shape)
#取所有图片的第0-2行
print(a[..., :2, :].shape)
#取第2张图片
print(a[2, ...].shape)

torch.Size([3, 28, 28])
torch.Size([28, 28])
torch.Size([28])
torch.Size([])
torch.Size([2, 3, 28, 28])
torch.Size([2, 2, 28, 28])
torch.Size([2, 3, 2, 28])
torch.Size([4, 3, 28, 14])
torch.Size([4, 3, 28, 2])
torch.Size([4, 3, 2, 28])
torch.Size([3, 28, 28])

tensor维度变换

view,reshape维度变换、unsqueeze插入维度、squeeze删除维度,只能删除为1的维度、repeat复制维度、维度交换transpose和permute

#view,reshape维度变换
#4张图片,单通道,28*28像素
a = torch.randn(4, 1, 28, 28)
print(a.shape)
#转换为4,784维度,相当于打平了
print(a.reshape(4, 784).shape)
#转换为4,28,28维度,相当于重新展开
print((a.reshape(4, 784)).reshape(4, 28, 28).shape)
#view的用法和reshape是完全一样的')
print(a.view(4, 784).shape)

torch.Size([4, 1, 28, 28])
torch.Size([4, 784])
torch.Size([4, 28, 28])
torch.Size([4, 784])

#unsqueeze插入维度
#2*2的tensor
a = torch.randn(2, 2)
print(a.shape)
#插入维度在第0维
print(a.unsqueeze(0).shape)
#插入维度在倒数第1维
print(a.unsqueeze(-1).shape)

torch.Size([2, 2])
torch.Size([1, 2, 2])
torch.Size([2, 2, 1])

#squeeze删除维度,只能删除为1的维度
#1*2*2*1的tensor
a = torch.randn(1, 2, 2, 1)
print(a.shape)
#删除第0维
print(a.squeeze(0).shape)
#删除倒数第1维
print(a.squeeze(-1).shape)
#删除所有为1的维度
print(a.squeeze().shape)

torch.Size([1, 2, 2, 1])
torch.Size([2, 2, 1])
torch.Size([1, 2, 2])
torch.Size([2, 2])

#repeat复制维度
#分别复制2次和3次
print(torch.randn(2, 2).repeat(2, 3).shape)

torch.Size([4, 6])

#维度交换
#t转置,只能操作2维tensor
print(torch.randn(1, 2).t().shape)
#transpose维度交换,指定要转换的维度,只能两两交换
print(torch.randn(1, 2, 3).transpose(0, 1).shape)
#permute维度交换,输入维度的顺序
print(torch.rand(1, 2, 3).permute(2, 1, 0).shape)

torch.Size([2, 1])
torch.Size([2, 1, 3])
torch.Size([3, 2, 1])

broadcast广播机制

#broadcast
a = torch.randn(2, 3)
b = torch.randn(1, 3)
c = torch.randn(1)
# 形状不匹配时,自动broadcast扩展维度
print((a + b).shape)
print((a + c).shape)
#手动boradcast
print(b.expand_as(a).shape)
print(c.expand_as(a).shape)

torch.Size([2, 3])
torch.Size([2, 3])
torch.Size([2, 3])
torch.Size([2, 3])

tensor拼接和拆分

a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
#cat拼接,dim=0是指定要拼接的维度
print(torch.cat([a, b], dim=0).shape)  # 除了拼接的维度,其他维度完全相等
a = torch.rand(4, 32, 8)
b = torch.rand(4, 32, 8)
#stack组合,会创建一个新的维度,用以区分组合后的两个tensor
print(torch.stack([a, b], dim=0).shape)
a = torch.rand(4, 32, 8)
#split拆分,在0维度上拆分,每2个元素1拆
_1, _2 = a.split(2, dim=0)
print(_1.shape)
print(_2.shape)
#split拆分,在0维度上拆分,拆分后长度分别为1,2,1
_1, _2, _3 = a.split([1, 2, 1], dim=0)
print(_1.shape)
print(_2.shape)
print(_3.shape)
a = torch.rand(4, 32, 8)
#chunk拆分,在0维度上拆分,拆成2个
_1, _2 = a.chunk(2, dim=0)
print(_1.shape)
print(_2.shape)

torch.Size([9, 32, 8])
torch.Size([2, 4, 32, 8])
torch.Size([2, 32, 8])
torch.Size([2, 32, 8])
torch.Size([1, 32, 8])
torch.Size([2, 32, 8])
torch.Size([1, 32, 8])
torch.Size([2, 32, 8])
torch.Size([2, 32, 8])

数学运算

#测试数据
a = torch.FloatTensor([[0, 1, 2], [3, 4, 5]])
b = torch.FloatTensor([0, 1, 2])
print(a, b)
#四则运算
print(a + b)
print(a - b)
print(a * b)
print(a / b)
#矩阵乘法
print(a @ b)
print(a.matmul(b))
#计算过程
0 * 0 + 1 * 1 + 2 * 2, 0 * 3 + 1 * 4 + 2 * 5
#指数,对数运算
#求指数
print(a**2)
#开根号
print(a**0.5)
#求e的n次方
print(a.exp())
#以e为底,求对数
print(a.log())
#以2为底,求对数
print(a.log2())
#裁剪,限制数据的上下限
a.clamp(2, 4)

#逻辑运算:返回一个形状相同的bool类型张量
#大于
print(a > b)
#小于
print(a < b)
#等于
print(a == b)
#不等于
print(a != b)

#四舍五入
c = torch.FloatTensor([3.14])
#向下取整
print(c.floor())
#向上取整
print(c.ceil())
#四舍五入
print(c.round())

属性统计

#测试数据
a = torch.FloatTensor([[0., 1., 2.], [3., 4., 5.]])
# 求最小(对每列的多行求最小值) 行dim=0
print(a.min(dim=0))
# 求最大(对每列的多行求最大值) 行dim=0
print(a.max(dim=0))
# 求平均(对每列的多行求平均值) 行dim=0
print(a.mean(dim=0))
# 求积(对每列的多行求乘积) 行dim=0
print(a.prod(dim=0))
# 求和(对每列的多行求和) 行dim=0
print(a.sum(dim=0))
# 求最大值下标(对每列的多行求最大值下标) 行dim=0
print(a.argmax(dim=0))
# 求最小值下标(对每列的多行求最小值下标) 行dim=0
print(a.argmin(dim=0))
# 求1范数
print(a.norm(1))
print(a.norm(1, dim=0))  # 对每列的多行求L1范数
# 求2范数
print(a.norm(2))
print(a.norm(2, dim=0))  # 对每列的多行求L2范数
# 求前2个最小值
print(a.topk(2, dim=1, largest=False))
# 求第2个小的值
print(a.kthvalue(2, dim=1))

torch.return_types.min(
values=tensor([0., 1., 2.]),
indices=tensor([0, 0, 0]))
torch.return_types.max(
values=tensor([3., 4., 5.]),
indices=tensor([1, 1, 1]))
tensor([1.5000, 2.5000, 3.5000])
tensor([ 0., 4., 10.])
tensor([3., 5., 7.])
tensor([1, 1, 1])
tensor([0, 0, 0])
tensor(15.)
tensor([3., 5., 7.])
tensor(7.4162)
tensor([3.0000, 4.1231, 5.3852])
torch.return_types.topk(
values=tensor([[0., 1.],
[3., 4.]]),
indices=tensor([[0, 1],
[0, 1]]))
torch.return_types.kthvalue(
values=tensor([1., 4.]),
indices=tensor([1, 1]))

torch.gather

https://zhuanlan.zhihu.com/p/352877584

按照dimindex在原始tensor中取值,输出tensor的shape不会大于原始tensordim指示index代表第几维的index,其余维度自动匹配。

在这里插入图片描述

import torch
a = torch.tensor([[1, 2, 3], 
                  [4, 5, 6]])

a.gather(dim=1, index=torch.tensor([[0, 0, 0],
                                    [2, 1, 0]]))
# dim=1表示index给的是dim=1的坐标,其他dim的坐标默认
''' 
gather 将按照 index 在 tensor a 中取对应坐标的值,补全index:
[[0, 0, 0],
[2, 1, 0]]   =>
                [(0,0), (0,0), (0,0)
                (1,2), (1,1), (1,0)]

结果:
tensor([[1, 1, 1],
        [6, 5, 4]])
'''

更复杂的例子

import torch
a = torch.tensor([[1, 2, 3], 
                  [4, 5, 6],
                  [7, 8, 9]])

如果想选择[[1,3],[5,6]]

'''
需要的完整index:
[[(0,0),(0,2)],
[(1,1),(1,2)]]
可以发现,第一行都在第一行,第二行都在第二行
可以省略dim=0,指定dim=1
[[(0),(2)],
[(1),(2)]]
'''
a.gather(dim=1, index=torch.tensor([[0, 2],
                                    [1, 2],]))

如果想选择[[1,4],[2,8]]

'''
需要的完整index:
[[(0,0),(1,0)],
[(0,1),(2,1)]]
可以发现,第一行都在第一列,第二行都在第二列
可以省略dim=1,指定列维度dim=0
[[0,1], # 第0列
[0,2]]  # 第1列
dim=0需要转置
[[0, 0],
[1, 2]]
'''
a.gather(dim=0, index=torch.tensor([[0, 0],
                                    [1, 2]]))
  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yuezero_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值