pytorch学习09:矩阵基本运算

四则运算

import torch

a = torch.tensor([
    [1,2],
    [3,4]
])

b = torch.tensor([
    [10, 20]
])

# 加
print("torch.all(torch.eq(a+b, torch.add(a,b))):",
      torch.all(torch.eq(a+b, torch.add(a,b))))
print("a+b:\n{}\n".format(a+b))

# 减
print("torch.all(torch.eq(a-b, torch.sub(a,b))):",
      torch.all(torch.eq(a-b, torch.sub(a,b))))
print("a*b:\n{}\n".format(a-b))

# 乘(是点乘)
print("torch.all(torch.eq(a*b, torch.mul(a,b))):",
      torch.all(torch.eq(a*b, torch.mul(a,b))))
print("a*b:\n{}\n".format(a*b))

# 除
print("torch.all(torch.eq(a/b, torch.div(a,b))):",
      torch.all(torch.eq(a/b, torch.div(a,b))))
print("a*b:\n{}\n".format(a/b))

请添加图片描述

矩阵相乘

import torch

a = torch.tensor([
    [1],
    [3]
])

b = torch.tensor([
    [10, 20]
])

# mm只能运算至多二维矩阵
print("torch.mm(a, b):\n{}\n".format(torch.mm(a, b)))
# matmul可运算更高维矩阵
print("torch.matmul(a, b):\n{}\n".format(torch.matmul(a, b)))
print("a@b:\n{}\n".format(a@b))

请添加图片描述

大于2维的矩阵相乘

import torch

a1 = torch.rand(4, 3, 28, 64)
b1 = torch.rand(4, 3, 64, 32)

c1 = torch.matmul(a1, b1)
# 对最后两维进行乘法运算
# 可以理解为多个矩阵并行相乘
print("c1.shape: ", c1.shape)

a2 = torch.rand(4, 1, 28, 64)
b2 = torch.rand(4, 3, 64, 32)
c2 = torch.matmul(a2, b2)
# 这里用到了广播机制
print("c2.shape: ", c2.shape)

请添加图片描述

幂运算

import torch

a = torch.tensor([
    [1, 2],
    [3, 4]
])

print("a.pow(2):\n{}\n".format(a.pow(2)))
print("a**2:\n{}\n".format(a**2))

print("a.pow(0.5):\n{}\n".format(a.pow(0.5)))
print("a.sqrt():\n{}\n".format(a.sqrt()))
# 平方根的倒数
print("a.rsqrt():\n{}\n".format(a.rsqrt()))
print("a**0.5:\n{}\n".format(a**0.5))

请添加图片描述

exp log

import torch

a = torch.tensor([
    [1, 2],
    [3, 4]
])

a_exp = torch.exp(a)
# e^x
print("torch.exp(a):\n{}\n".format(a_exp))
# ln x
# 以2为底:log2
# 以10为底:log10
print("torch.log(a_exp):\n{}\n".format(torch.log(a_exp)))

请添加图片描述

近似值

import torch

a = torch.tensor(1.67)

# 向下取整
print("a.floor():", a.floor())
# 向上取整
print("a.ceil():", a.ceil())
# 取整数部分
print("a.trunc():", a.trunc())
# 取小数部分
print("a.frac():", a.frac())
# 四舍五入
print("a.round():", a.round())

请添加图片描述

最大值、最小值、中位数

import torch

a = torch.rand(2,3)*20

print("a:\n{}\n".format(a))

# 最大值
print("a.max(): ", a.max())
# 中位数,偶数时不取平均,取从小到大第 length/2 个
print("a.median(): ", a.median())
# 最小值
print("a.min(): ", a.min())

请添加图片描述

限制区间

import torch

a = torch.rand(2,3)*20

print("a:\n{}\n".format(a))

# clamp(min),当有值小于 min 时,用 min 替换
print("a.clamp(10):\n{}\n".format(a.clamp(10)))
# clamp(min, max),当有值小于 min 时,用 min 替换
# 当有值大于 max 时,用 max 替换
print("a.clamp(5, 10):\n{}\n".format(a.clamp(5, 10)))

请添加图片描述

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值