(1)本文涉及函数的列表(注释在代码中)
- torch.eq # 元素是否相等
- torch.equal # 张量是否相等
- torch.ge # 元素大小关系 input >= other
- torch.gt # 元素大小关系 input > other
- torch.kthvalue # 输入张量 input 指定维上第 k 个最小值
- torch.le # 元素大小关系 input <= other
- torch.lt # 元素大小关系 input < other
- torch.max # 所有元素的最大值 / 指定维度上每行的最大值
- torch.min # 所有元素的最小值 / 指定维度上每行的最小值
- torch.ne # 元素大小关系 input != other
- torch.sort # 升序排列 / 降序排列
- torch.topk # 排序靠前的 k 个值
注意:
- 逐个元素比较大小和整个张量比较大小的概念不同;
- 指定的维度(dim=0: 沿 y 轴方向)(dim=1: 沿 x 轴方向);
- 按降序排列:(下降)descending = True;
- 部分函数的参数 other,可以是和 input 形状相同的张量,也可以是一个实数。
(2)代码示例(含注释)
"""
比较操作 Comparison Ops
"""
import torch
# # 逐个比较元素相等性。
# # 第二个参数可为一个数或与第一个参数同类型形状的张量。
# # 返回值: 一个 torch.ByteTensor 张量,包含了每个位置的比较结果(相等为 true,不等为 False )
# # 返回类型: Tensor
obj1 = torch.eq(torch.tensor([1, 2, 3]), 1)
# # 如果两个张量有完全相同的形状和元素值,则返回 True ,否则 False。
obj2 = torch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 4])) # False
# # 逐元素比较 input 和 other 的大小关系,即是否 input>=other。
# # 如果两个张量有相同的形状和元素值,则返回 True ,否则 False。
# # 第二个参数可以为一个数或与第一个参数相同形状和类型的张量
# # 返回类型: Tensor
obj3 = torch.ge(torch.tensor([[7.0, 1.0], [1.0, 3.0]]), torch.tensor([[6.0, 1.0], [2.0, 3.0]]))
# # 逐元素比较 input 和 other ,
# # 即是否 input>other 如果两个张量有相同的形状和元素值,则返回 True ,否则 False。
obj4 = torch.gt(torch.tensor([[7.0, 1.0], [1.0, 3.0]]), torch.tensor([[6.0, 1.0], [2.0, 3.0]]))
# # 取输入张量 input 指定维上第 k 个最小值。如果不指定 dim,则默认为 input 的最后一维。
# # 返回一个元组 (values,indices),其中 indices 是原始输入张量 input 中沿 dim 维的第 k 个最小值下标。
obj5 = torch.kthvalue(torch.arange(1.0, 6.0, 0.5), k=4, dim=0)
# obj5[0]=values=tensor(2.5000)
# obj5[1]=indices=tensor(3)
# # 逐元素比较 input 和 other ,
# # 即是否 input<=other 第二个参数可以为一个数或与第一个参数相同形状和类型的张量
obj6 = torch.le(torch.tensor([[7.0, 1.0], [1.0, 3.0]]), torch.tensor([[6.0, 1.0], [2.0, 3.0]]))
# # 逐元素比较 input 和 other ,
# # 即是否 input<other 第二个参数可以为一个数或与第一个参数相同形状和类型的张量。
obj7 = torch.lt(torch.tensor([[7.0, 1.0], [1.0, 3.0]]), torch.tensor([[6.0, 1.0], [2.0, 3.0]]))
# # 返回输入张量所有元素的最大值。
obj8 = torch.max(torch.randn(3, 4))
# # 返回输入张量给定维度上每行的最大值,并同时返回每个最大值的位置索引。
# # torch.max(input, dim, max=None, max_indices=None) -> (Tensor, LongTensor)
obj9 = torch.max(torch.randn(3, 4), dim=0)
obj10 = torch.max(torch.randn(3, 4), dim=1)
# # 返回输入张量所有元素的最小值。
obj11 = torch.min(torch.randn(3, 4))
# # 返回输入张量给定维度上每行的最大值,并同时返回每个最小值的位置索引。
# # torch.min(input, dim, max=None, max_indices=None) -> (Tensor, LongTensor)
obj12 = torch.min(torch.randn(3, 4), dim=0)
obj13 = torch.min(torch.randn(3, 4), dim=1)
# # 逐元素比较 input 和 other , 即是否 input!=other。
# # 第二个参数可以为一个数或与第一个参数相同形状和类型的张量。
obj14 = torch.ne(torch.arange(1.0, 3.0), torch.arange(2.0, 4.0))
obj15 = torch.ne(torch.arange(1.0, 3.0), 1.0)
# # 对输入张量 input 沿着指定维按升序排序。
# # 如果不给定 dim,则默认为输入的最后一维。
# # 如果指定参数 descending 为 True,则按降序排序
# # 返回元组 (sorted_tensor, sorted_indices) , sorted_indices 为原始输入中的下标。
obj16 = torch.sort(torch.randn(4, 5), dim=0)
obj17 = torch.sort(torch.randn(4, 5), dim=1)
obj18 = torch.sort(torch.randn(4, 5), dim=1, descending=True)
# # 沿给定 dim 维度返回输入张量 input 中 k 个最大值(即排序靠前的前k个值!!!)。
# # 如果不指定 dim,则默认为 input 的最后一维。
# # 如果为 largest 为 False ,则返回最小的 k 个值。
obj19 = torch.topk(torch.tensor([[1, 2, 3],
[3, 5, 3],
[2, 9, 1]]), k=1, dim=0)
obj20 = torch.topk(torch.tensor([[1, 2, 3],
[3, 5, 3],
[2, 9, 1]]), k=1, dim=1)
obj21 = torch.topk(torch.tensor([[1, 2, 3],
[3, 5, 3],
[2, 9, 1]]), k=2, dim=0)
# input (Tensor) – 输入张量
# k (int) – “top-k”中的 k
# dim (int, optional) – 排序的维
# largest (bool, optional) – 布尔值,控制返回最大或最小值
# sorted (bool, optional) – 布尔值,控制返回值是否排序
# out (tuple, optional) – 可选输出张量 (Tensor, LongTensor) output buffers
print("*"*20, "obj1", "*"*20, "\n", obj1, "\n")
print("*"*20, "obj2", "*"*20, "\n", obj2, "\n")
print("*"*20, "obj3", "*"*20, "\n", obj3, "\n")
print("*"*20, "obj4", "*"*20, "\n", obj4, "\n")
print("*"*20, "obj5", "*"*20, "\n", obj5, "\n")
print("*"*20, "obj6", "*"*20, "\n", obj6, "\n")
print("*"*20, "obj7", "*"*20, "\n", obj7, "\n")
print("*"*20, "obj8", "*"*20, "\n", obj8, "\n")
print("*"*20, "obj9", "*"*20, "\n", obj9, "\n", obj10, "\n")
print("*"*20, "obj11", "*"*20, "\n", obj11, "\n")
print("*"*20, "obj13", "*"*20, "\n", obj12, "\n", obj13, "\n")
print("*"*20, "obj14", "*"*20, "\n", obj14, "\n", obj15, "\n")
print("*"*20, "obj16", "*"*20, "\n", obj16, "\n", obj17, "\n", obj18, "\n")
print("*"*20, "obj19", "*"*20, "\n", obj19, "\n", obj20, "\n", obj21, "\n")
>>>output
******************** obj1 ********************
tensor([ True, False, False])******************** obj2 ********************
False******************** obj3 ********************
tensor([[ True, True],
[False, True]])******************** obj4 ********************
tensor([[ True, False],
[False, False]])******************** obj5 ********************
torch.return_types.kthvalue(
values=tensor(2.5000),
indices=tensor(3))******************** obj6 ********************
tensor([[False, True],
[ True, True]])******************** obj7 ********************
tensor([[False, False],
[ True, False]])******************** obj8 ********************
tensor(1.3611)******************** obj9 ********************
torch.return_types.max(
values=tensor([0.8743, 1.0526, 0.4736, 1.0062]),
indices=tensor([1, 2, 1, 1]))
torch.return_types.max(
values=tensor([0.2985, 1.5670, 1.1388]),
indices=tensor([0, 0, 0]))******************** obj11 ********************
tensor(-2.6445)******************** obj13 ********************
torch.return_types.min(
values=tensor([-1.1342, 0.2821, -2.2261, -0.6704]),
indices=tensor([2, 1, 1, 1]))
torch.return_types.min(
values=tensor([-1.4479, -0.8022, -2.5090]),
indices=tensor([3, 1, 2]))******************** obj14 ********************
tensor([True, True])
tensor([False, True])******************** obj16 ********************
torch.return_types.sort(
values=tensor([[-1.1032, -0.2661, -0.9461, -0.9597, -1.4905],
[-0.4305, 0.1166, -0.8521, -0.0462, -0.6387],
[-0.0587, 1.2700, -0.1314, 0.4326, -0.5871],
[ 0.1834, 1.4156, 0.7330, 0.7903, 1.2259]]),
indices=tensor([[3, 0, 2, 1, 0],
[0, 3, 3, 0, 1],
[1, 2, 1, 3, 3],
[2, 1, 0, 2, 2]]))
torch.return_types.sort(
values=tensor([[-1.1859, -1.0903, -0.5112, -0.0466, 1.9069],
[-0.9795, -0.3969, -0.2867, 0.6880, 1.2773],
[-1.0201, -0.3897, -0.1580, 0.1677, 1.9816],
[-0.3943, 0.0321, 0.5373, 0.8734, 1.2815]]),
indices=tensor([[4, 2, 3, 0, 1],
[3, 0, 2, 1, 4],
[3, 1, 0, 2, 4],
[1, 2, 4, 3, 0]]))
torch.return_types.sort(
values=tensor([[ 1.0435, 0.0428, -0.0089, -1.1555, -1.6359],
[ 2.0980, -0.2942, -0.4056, -1.0058, -1.1898],
[ 2.2949, 0.7036, 0.5235, -0.5708, -0.5717],
[ 2.4485, 0.4431, 0.1595, -0.3412, -1.3198]]),
indices=tensor([[4, 2, 0, 3, 1],
[4, 2, 1, 0, 3],
[0, 3, 4, 1, 2],
[2, 1, 4, 0, 3]]))******************** obj19 ********************
torch.return_types.topk(
values=tensor([[3, 9, 3]]),
indices=tensor([[1, 2, 0]]))
torch.return_types.topk(
values=tensor([[3],
[5],
[9]]),
indices=tensor([[2],
[1],
[1]]))
torch.return_types.topk(
values=tensor([[3, 9, 3],
[2, 5, 3]]),
indices=tensor([[1, 2, 0],
[2, 1, 1]]))
专栏链接直达:
https://blog.csdn.net/qq_54185421/category_11794260.html?spm=1001.2014.3001.5482https://blog.csdn.net/qq_54185421/category_11794260.html?spm=1001.2014.3001.5482 >>>如有疑问,欢迎评论区一起探讨