PyTorch Tensor Computation Explained

Basic Tensor Operations

The four basic arithmetic operations on tensors (addition, subtraction, multiplication, and division) correspond to add(), sub(), mul(), and div(). All four functions support broadcasting.

import torch

data1 = torch.tensor([1, 2, 3])
print(data1.add(2))
print(data1)

print(data1.sub(1))
print(data1)

print(data1.mul(2))
print(data1)

print(data1.div(2))
print(data1)
# tensor([3, 4, 5])
# tensor([1, 2, 3])
# tensor([0, 1, 2])
# tensor([1, 2, 3])
# tensor([2, 4, 6])
# tensor([1, 2, 3])
# tensor([0.5000, 1.0000, 1.5000])
# tensor([1, 2, 3])
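
These four functions broadcast between tensors as well, not only against scalars. A minimal sketch (the shapes here are chosen purely for illustration):

import torch

data1 = torch.tensor([[1, 2, 3],
                      [4, 5, 6]])
data2 = torch.tensor([10, 20, 30])  # shape (3,) broadcasts across both rows
print(data1.add(data2))
# tensor([[11, 22, 33],
#         [14, 25, 36]])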

add_(), sub_(), mul_(), div_(): the trailing underscore makes each operation in-place, assigning the result back to the tensor itself, i.e., the equivalents of +=, -=, *=, and /=.

import torch

data1 = torch.tensor([1, 2, 3])
data1.add_(2)
print(data1)

data1.sub_(1)
print(data1)

data1.mul_(2)
print(data1)

data1.div_(2)
print(data1)
# tensor([3, 4, 5])
# tensor([2, 3, 4])
# tensor([4, 6, 8])
# Traceback (most recent call last):
#   File "D:\Pythonproject\teach_day_01\demo01.py", line 13, in <module>
#     data1.div_(2)
# RuntimeError: result type Float can't be cast to the desired output type Long

Notice that div_() raised an error. Division returns a Float tensor by default, but div_() must write the result back into data1 while keeping its original dtype Long, and a Float result cannot be cast to Long, hence the error. The fix is to make the original tensor float before the computation.

import torch

data1 = torch.tensor([1, 2, 3], dtype=torch.float32)
data1.add_(2)
print(data1)

data1.sub_(1)
print(data1)

data1.mul_(2)
print(data1)

data1.div_(2)
print(data1)
# tensor([3., 4., 5.])
# tensor([2., 3., 4.])
# tensor([4., 6., 8.])
# tensor([2., 3., 4.])
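
If the tensor already exists as Long, two other options (a sketch, not the only approaches) are the out-of-place div(), which is free to return a new Float tensor, or converting the dtype first with .to(torch.float32):

import torch

data1 = torch.tensor([1, 2, 3])

# Out-of-place div() returns a new Float tensor; data1 stays Long
print(data1.div(2))
# tensor([0.5000, 1.0000, 1.5000])

# Convert the dtype first, then divide in place
data1 = data1.to(torch.float32)
data1.div_(2)
print(data1)
# tensor([0.5000, 1.0000, 1.5000])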

Likewise, when the tensor is Long, changing mul_(2) to mul_(2.0) means the result should in principle be a float, so writing it back into data1 in place raises the same dtype error.

import torch

data1 = torch.tensor([1, 2, 3])

data1.mul_(2.0)
print(data1)
# Traceback (most recent call last):
#   File "D:\Pythonproject\teach_day_01\demo01.py", line 5, in <module>
#     data1.mul_(2.0)
# RuntimeError: result type Float can't be cast to the desired output type Long

Here, too, changing the tensor to float solves it:

import torch

data1 = torch.tensor([1, 2, 3], dtype=torch.float32)

data1.mul_(2.0)
print(data1)
# tensor([2., 4., 6.])
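
Note that only the in-place form fails: the out-of-place mul() returns a new tensor, so type promotion to Float is allowed and the Long original is left untouched. A quick check:

import torch

data1 = torch.tensor([1, 2, 3])

# Out-of-place: the result is promoted to Float, data1 keeps dtype Long
print(data1.mul(2.0))
print(data1.dtype)
# tensor([2., 4., 6.])
# torch.int64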

张量矩阵乘法

Hadamard Product

The Hadamard product is the element-wise product of matrices at corresponding positions, and it supports broadcasting. It is usually computed with the * operator or mul():

import torch

data1 = torch.randint(0, 10, (2, 3), dtype=torch.float32)
data2 = torch.randint(0, 10, (2, 3), dtype=torch.float32)
print(data1 * data2)
print(torch.mul(data1, data2))
print(data1 * 2)
# tensor([[63., 14., 21.],
#         [ 0.,  9.,  2.]])
# tensor([[63., 14., 21.],
#         [ 0.,  9.,  2.]])
# tensor([[14., 14.,  6.],
#         [18.,  6.,  4.]])
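
Broadcasting also works between tensors of different but compatible shapes, not just between a tensor and a scalar. A minimal sketch:

import torch

data1 = torch.tensor([[1., 2., 3.],
                      [4., 5., 6.]])
data2 = torch.tensor([10., 20., 30.])  # shape (3,) broadcasts over each row
print(data1 * data2)
# tensor([[ 10.,  40.,  90.],
#         [ 40., 100., 180.]])
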
Matrix Multiplication (Dot Product)

Matrix multiplication is generally implemented in four ways: torch.mm(), torch.bmm(), torch.matmul(), and the @ operator.

torch.mm()

torch.mm() requires both participating matrices to be 2-D.

import torch

data1 = torch.randint(0, 10, (2, 3), dtype=torch.float32)
data2 = torch.randint(0, 10, (3, 3), dtype=torch.float32)
print(data1)
print(data2)
print(torch.mm(data1, data2))
# tensor([[2., 3., 9.],
#         [1., 8., 1.]])
# tensor([[4., 0., 3.],
#         [6., 8., 5.],
#         [8., 1., 5.]])
# tensor([[98., 33., 66.],
#         [60., 65., 48.]])
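
Because torch.mm() only accepts 2-D inputs, passing a batched (3-D) tensor raises a RuntimeError. A sketch of the failure mode (the exact error wording may vary by PyTorch version):

import torch

data1 = torch.randn(3, 2, 3)
data2 = torch.randn(3, 3)
try:
    torch.mm(data1, data2)  # data1 is 3-D, so mm() rejects it
except RuntimeError as e:
    print(e)  # message says mm expects matrices (2-D tensors)
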
torch.bmm() 

torch.bmm(), short for batched matrix multiply, requires both operands to be 3-D tensors. It performs an ordinary 2-D matrix multiplication on the inner two dimensions for each slice along the outermost (batch) dimension.

import torch

data1 = torch.randint(0, 10, (3, 2, 3), dtype=torch.float32)
data2 = torch.randint(0, 10, (3, 3, 3), dtype=torch.float32)
print(data1)
print(data2)
print(torch.bmm(data1, data2))
# tensor([[[1., 0., 0.],
#          [9., 2., 9.]],
# 
#         [[7., 5., 0.],
#          [8., 8., 8.]],
# 
#         [[8., 1., 6.],
#          [6., 0., 8.]]])
# tensor([[[8., 5., 1.],
#          [1., 8., 1.],
#          [4., 7., 5.]],
# 
#         [[4., 4., 0.],
#          [1., 1., 4.],
#          [3., 1., 5.]],
# 
#         [[3., 5., 9.],
#          [6., 1., 9.],
#          [1., 1., 5.]]])
# tensor([[[  8.,   5.,   1.],
#          [110., 124.,  56.]],
# 
#         [[ 33.,  33.,  20.],
#          [ 64.,  48.,  72.]],
# 
#         [[ 36.,  47., 111.],
#          [ 26.,  38.,  94.]]])
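
Unlike torch.matmul(), torch.bmm() does not broadcast: both inputs must be exactly 3-D and share the same batch size, giving (b, n, m) x (b, m, p) -> (b, n, p). A small shape check (the shapes are illustrative):

import torch

data1 = torch.randn(4, 2, 3)
data2 = torch.randn(4, 3, 5)
print(torch.bmm(data1, data2).shape)
# torch.Size([4, 2, 5])
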
torch.matmul()

torch.matmul() supports multiplying a higher-dimensional tensor by a lower-dimensional one, because it applies broadcasting to matrix multiplication: the lower-dimensional tensor is automatically expanded before being multiplied with the other tensor.

import torch

data1 = torch.randint(0, 10, (3, 2, 3), dtype=torch.float32)
data2 = torch.randint(0, 10, (3, 3), dtype=torch.float32)
print(data1)
print(data2)
print(torch.matmul(data1, data2))
# tensor([[[4., 0., 2.],
#          [0., 9., 9.]],
# 
#         [[2., 9., 1.],
#          [0., 7., 7.]],
# 
#         [[9., 4., 5.],
#          [7., 9., 4.]]])
# tensor([[8., 3., 9.],
#         [6., 1., 7.],
#         [8., 9., 8.]])
# tensor([[[ 48.,  30.,  52.],
#          [126.,  90., 135.]],
# 
#         [[ 78.,  24.,  89.],
#          [ 98.,  70., 105.]],
# 
#         [[136.,  76., 149.],
#          [142.,  66., 158.]]])
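
torch.matmul() also handles 1-D inputs: two vectors produce a scalar dot product, and a matrix-vector pair contracts away the vector's dimension. A quick sketch:

import torch

v = torch.tensor([1., 2., 3.])
w = torch.tensor([4., 5., 6.])
print(torch.matmul(v, w))  # 1-D x 1-D -> scalar dot product
# tensor(32.)

m = torch.tensor([[1., 0., 0.],
                  [0., 1., 0.]])
print(torch.matmul(m, v))  # 2-D x 1-D -> 1-D result
# tensor([1., 2.])
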
The @ Operator

Since Python 3.5, Python has supported the @ operator for matrix multiplication. It is equivalent to torch.matmul() in most cases; note that it is unavailable on older Python versions.

import torch

data1 = torch.randint(0, 10, (3, 2, 3), dtype=torch.float32)
data2 = torch.randint(0, 10, (3, 3), dtype=torch.float32)
print(data1)
print(data2)
print(data1 @ data2)
# tensor([[[9., 1., 9.],
#          [4., 6., 0.]],
# 
#         [[6., 4., 7.],
#          [7., 3., 4.]],
# 
#         [[3., 0., 6.],
#          [6., 2., 0.]]])
# tensor([[2., 7., 8.],
#         [6., 2., 3.],
#         [7., 6., 0.]])
# tensor([[[ 87., 119.,  75.],
#          [ 44.,  40.,  50.]],
# 
#         [[ 85.,  92.,  60.],
#          [ 60.,  79.,  65.]],
# 
#         [[ 48.,  57.,  24.],
#          [ 24.,  46.,  54.]]])
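
To confirm the equivalence with torch.matmul(), the two results can be compared directly (a quick sanity check):

import torch

data1 = torch.randn(3, 2, 3)
data2 = torch.randn(3, 3)
print(torch.equal(data1 @ data2, torch.matmul(data1, data2)))
# True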
