即看即用 && 比较操作(Comparison Ops) && Pytorch官方文档总结 && 笔记 (七)

(1)本文涉及函数的列表(注释在代码中)

  1. torch.eq                              # 元素是否相等
  2. torch.equal                         # 张量是否相等
  3. torch.ge                              # 元素大小关系 input >= other
  4. torch.gt                               # 元素大小关系 input > other
  5. torch.kthvalue                    # 输入张量 input 指定维上第 k 个最小值
  6. torch.le                               # 元素大小关系 input <= other
  7. torch.lt                                # 元素大小关系 input < other
  8. torch.max                          # 所有元素的最大值 / 指定维度上每行的最大值 
  9. torch.min                            # 所有元素的最小值 / 指定维度上每行的最小值
  10. torch.ne                              # 元素大小关系 input != other
  11. torch.sort                           # 升序排列 / 降序排列
  12. torch.topk                          # 排序靠前的 k 个值

注意:

  • 逐个元素比较大小和整个张量比较大小的概念不同;
  • 指定的维度(dim=0: 沿 y 轴方向)(dim=1: 沿 x 轴方向);
  • 按降序排列:(下降)descending = True;
  • 部分函数的参数 other,可以是和 input 形状相同的张量,也可以是一个实数。

(2)代码示例(含注释)

"""
比较操作 Comparison Ops
"""
import torch

# # 逐个比较元素相等性。
# # 第二个参数可为一个数或与第一个参数同类型形状的张量。
# # 返回值: 一个 torch.ByteTensor 张量,包含了每个位置的比较结果(相等为 true,不等为 False )
# # 返回类型: Tensor
obj1 = torch.eq(torch.tensor([1, 2, 3]), 1)

# # 如果两个张量有完全相同的形状和元素值,则返回 True ,否则 False。
obj2 = torch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 4]))  # False

# # 逐元素比较 input 和 other 的大小关系,即是否 input>=other。
# # 如果两个张量有相同的形状和元素值,则返回 True ,否则 False。
# # 第二个参数可以为一个数或与第一个参数相同形状和类型的张量
# # 返回类型: Tensor
obj3 = torch.ge(torch.tensor([[7.0, 1.0], [1.0, 3.0]]), torch.tensor([[6.0, 1.0], [2.0, 3.0]]))

# # 逐元素比较 input 和 other ,
# # 即是否 input>other 如果两个张量有相同的形状和元素值,则返回 True ,否则 False。
obj4 = torch.gt(torch.tensor([[7.0, 1.0], [1.0, 3.0]]), torch.tensor([[6.0, 1.0], [2.0, 3.0]]))

# # 取输入张量 input 指定维上第 k 个最小值。如果不指定 dim,则默认为 input 的最后一维。
# # 返回一个元组 (values,indices),其中 indices 是原始输入张量 input 中沿 dim 维的第 k 个最小值下标。
obj5 = torch.kthvalue(torch.arange(1.0, 6.0, 0.5), k=4, dim=0)
# obj5[0]=values=tensor(2.5000)
# obj5[1]=indices=tensor(3)

# # 逐元素比较 input 和 other ,
# # 即是否 input<=other 第二个参数可以为一个数或与第一个参数相同形状和类型的张量
obj6 = torch.le(torch.tensor([[7.0, 1.0], [1.0, 3.0]]), torch.tensor([[6.0, 1.0], [2.0, 3.0]]))

# # 逐元素比较 input 和 other ,
# # 即是否 input<other 第二个参数可以为一个数或与第一个参数相同形状和类型的张量。
obj7 = torch.lt(torch.tensor([[7.0, 1.0], [1.0, 3.0]]), torch.tensor([[6.0, 1.0], [2.0, 3.0]]))

# # 返回输入张量所有元素的最大值。
obj8 = torch.max(torch.randn(3, 4))

# # 返回输入张量给定维度上每行的最大值,并同时返回每个最大值的位置索引。
# # torch.max(input, dim, max=None, max_indices=None) -> (Tensor, LongTensor)
obj9 = torch.max(torch.randn(3, 4), dim=0)
obj10 = torch.max(torch.randn(3, 4), dim=1)

# # 返回输入张量所有元素的最小值。
obj11 = torch.min(torch.randn(3, 4))

# # 返回输入张量给定维度上每行的最大值,并同时返回每个最小值的位置索引。
# # torch.min(input, dim, max=None, max_indices=None) -> (Tensor, LongTensor)
obj12 = torch.min(torch.randn(3, 4), dim=0)
obj13 = torch.min(torch.randn(3, 4), dim=1)

# # 逐元素比较 input 和 other , 即是否 input!=other。
# # 第二个参数可以为一个数或与第一个参数相同形状和类型的张量。
obj14 = torch.ne(torch.arange(1.0, 3.0), torch.arange(2.0, 4.0))
obj15 = torch.ne(torch.arange(1.0, 3.0), 1.0)

# # 对输入张量 input 沿着指定维按升序排序。
# # 如果不给定 dim,则默认为输入的最后一维。
# # 如果指定参数 descending 为 True,则按降序排序
# # 返回元组 (sorted_tensor, sorted_indices) , sorted_indices 为原始输入中的下标。
obj16 = torch.sort(torch.randn(4, 5), dim=0)
obj17 = torch.sort(torch.randn(4, 5), dim=1)
obj18 = torch.sort(torch.randn(4, 5), dim=1, descending=True)

# # 沿给定 dim 维度返回输入张量 input 中 k 个最大值(即排序靠前的前k个值!!!)。
# # 如果不指定 dim,则默认为 input 的最后一维。
# # 如果为 largest 为 False ,则返回最小的 k 个值。
obj19 = torch.topk(torch.tensor([[1, 2, 3],
                                 [3, 5, 3],
                                 [2, 9, 1]]), k=1, dim=0)
obj20 = torch.topk(torch.tensor([[1, 2, 3],
                                 [3, 5, 3],
                                 [2, 9, 1]]), k=1, dim=1)
obj21 = torch.topk(torch.tensor([[1, 2, 3],
                                 [3, 5, 3],
                                 [2, 9, 1]]), k=2, dim=0)
# input (Tensor) – 输入张量
# k (int) – “top-k”中的 k
# dim (int, optional) – 排序的维
# largest (bool, optional) – 布尔值,控制返回最大或最小值
# sorted (bool, optional) – 布尔值,控制返回值是否排序
# out (tuple, optional) – 可选输出张量 (Tensor, LongTensor) output buffers

print("*"*20, "obj1", "*"*20, "\n", obj1, "\n")
print("*"*20, "obj2", "*"*20, "\n", obj2, "\n")
print("*"*20, "obj3", "*"*20, "\n", obj3, "\n")
print("*"*20, "obj4", "*"*20, "\n", obj4, "\n")
print("*"*20, "obj5", "*"*20, "\n", obj5, "\n")
print("*"*20, "obj6", "*"*20, "\n", obj6, "\n")
print("*"*20, "obj7", "*"*20, "\n", obj7, "\n")
print("*"*20, "obj8", "*"*20, "\n", obj8, "\n")
print("*"*20, "obj9", "*"*20, "\n", obj9, "\n", obj10, "\n")
print("*"*20, "obj11", "*"*20, "\n", obj11, "\n")
print("*"*20, "obj13", "*"*20, "\n", obj12, "\n", obj13, "\n")
print("*"*20, "obj14", "*"*20, "\n", obj14, "\n", obj15, "\n")
print("*"*20, "obj16", "*"*20, "\n", obj16, "\n", obj17, "\n", obj18, "\n")
print("*"*20, "obj19", "*"*20, "\n", obj19, "\n", obj20, "\n", obj21, "\n")

  >>>output

******************** obj1 ******************** 
 tensor([ True, False, False]) 

******************** obj2 ******************** 
 False 

******************** obj3 ******************** 
 tensor([[ True,  True],
              [False,  True]]) 

******************** obj4 ******************** 
 tensor([[ True, False],
              [False, False]]) 

******************** obj5 ******************** 
 torch.return_types.kthvalue(
values=tensor(2.5000),
indices=tensor(3)) 

******************** obj6 ******************** 
 tensor([[False,  True],
               [ True,  True]]) 

******************** obj7 ******************** 
 tensor([[False, False],
               [ True, False]]) 

******************** obj8 ******************** 
 tensor(1.3611) 

******************** obj9 ******************** 
 torch.return_types.max(
                                          values=tensor([0.8743, 1.0526, 0.4736, 1.0062]),
                                          indices=tensor([1, 2, 1, 1])) 
 torch.return_types.max(
                                          values=tensor([0.2985, 1.5670, 1.1388]),
                                          indices=tensor([0, 0, 0])) 

******************** obj11 ******************** 
 tensor(-2.6445) 

******************** obj13 ******************** 
 torch.return_types.min(
                                          values=tensor([-1.1342,  0.2821, -2.2261, -0.6704]),
                                          indices=tensor([2, 1, 1, 1])) 
 torch.return_types.min(
                                          values=tensor([-1.4479, -0.8022, -2.5090]),
                                          indices=tensor([3, 1, 2])) 

******************** obj14 ******************** 
 tensor([True, True]) 
 tensor([False,  True]) 

******************** obj16 ******************** 
 torch.return_types.sort(
values=tensor([[-1.1032, -0.2661, -0.9461, -0.9597, -1.4905],
                          [-0.4305,  0.1166, -0.8521, -0.0462, -0.6387],
                          [-0.0587,  1.2700, -0.1314,  0.4326, -0.5871],
                          [ 0.1834,  1.4156,  0.7330,  0.7903,  1.2259]]),
indices=tensor([[3, 0, 2, 1, 0],
                          [0, 3, 3, 0, 1],
                          [1, 2, 1, 3, 3],
                          [2, 1, 0, 2, 2]])) 
 torch.return_types.sort(
values=tensor([[-1.1859, -1.0903, -0.5112, -0.0466,  1.9069],
                           [-0.9795, -0.3969, -0.2867,  0.6880,  1.2773],
                           [-1.0201, -0.3897, -0.1580,  0.1677,  1.9816],
                           [-0.3943,  0.0321,  0.5373,  0.8734,  1.2815]]),
indices=tensor([[4, 2, 3, 0, 1],
                            [3, 0, 2, 1, 4],
                            [3, 1, 0, 2, 4],
                            [1, 2, 4, 3, 0]])) 
 torch.return_types.sort(
values=tensor([[ 1.0435,  0.0428, -0.0089, -1.1555, -1.6359],
                           [ 2.0980, -0.2942, -0.4056, -1.0058, -1.1898],
                           [ 2.2949,  0.7036,  0.5235, -0.5708, -0.5717],
                           [ 2.4485,  0.4431,  0.1595, -0.3412, -1.3198]]),
indices=tensor([[4, 2, 0, 3, 1],
                            [4, 2, 1, 0, 3],
                            [0, 3, 4, 1, 2],
                            [2, 1, 4, 0, 3]])) 

******************** obj19 ******************** 
 torch.return_types.topk(
                                           values=tensor([[3, 9, 3]]),
                                           indices=tensor([[1, 2, 0]])) 
 torch.return_types.topk(
                                           values=tensor([[3],
                                                                      [5],
                                                                      [9]]),
                                           indices=tensor([[2],
                                                                       [1],
                                                                       [1]])) 
 torch.return_types.topk(
                                           values=tensor([[3, 9, 3],
                                                                      [2, 5, 3]]),
                                           indices=tensor([[1, 2, 0],
                                                                       [2, 1, 1]])) 

专栏链接直达:

https://blog.csdn.net/qq_54185421/category_11794260.html?spm=1001.2014.3001.5482icon-default.png?t=M4ADhttps://blog.csdn.net/qq_54185421/category_11794260.html?spm=1001.2014.3001.5482 >>>如有疑问,欢迎评论区一起探讨

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Flying Bulldog

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值