pytorch框架中的两种常用乘法

        在使用pytorch框架复现模型的时候,我们需要再forward()函数中定义模型的逻辑,这时就要对模型参数使用一些运算,这里简单介绍一下pytorch框架下的两种常用的乘法运算。

1 按元素乘

        按元素乘,即张量的对应元素相乘,将每个位置上相乘的结果作为返回值,使用“*”实现。看一下例子:

import torch

a = torch.Tensor([[1, 2], [3, 4]])
print('a', a)
b = torch.Tensor([[5, 6], [7, 8]])
print('b', b)

z = a * b		# 按元素乘
print('z', z)
a tensor([[1., 2.],
        [3., 4.]])
b tensor([[5., 6.],
        [7., 8.]])
z tensor([[ 5., 12.],
        [21., 32.]])

        pytorch中的按元素乘支持广播,看一下例子:

 

import torch

a = torch.rand(2, 10)
print('a', a)
print(a.size())

b = torch.Tensor([[0.2], [0.8]])
print('b', b)
print(b.size())

print('a * b', a * b)		# 按元素乘,广播

 

a tensor([[0.8498, 0.3875, 0.8588, 0.6866, 0.9692, 0.9897, 0.3961, 0.2562, 0.0801,
         0.9533],
        [0.1715, 0.7883, 0.7576, 0.1322, 0.8059, 0.7220, 0.1398, 0.0333, 0.4146,
         0.4013]])
torch.Size([2, 10])
b tensor([[0.2000],
        [0.8000]])
torch.Size([2, 1])
a * b tensor([[0.1700, 0.0775, 0.1718, 0.1373, 0.1938, 0.1979, 0.0792, 0.0512, 0.0160,
         0.1907],
        [0.1372, 0.6306, 0.6061, 0.1057, 0.6447, 0.5776, 0.1119, 0.0266, 0.3317,
         0.3210]])

2 矩阵乘法

        矩阵乘法就是我们熟悉的矩阵的乘法,可以使用“@”实现。看一下例子:

import torch

a = torch.Tensor([[1, 2], [3, 4]])
print('a', a)
b = torch.Tensor([[5, 6], [7, 8]])
print('b', b)

f = a @ b		# 矩阵乘法
print('f', f)
a tensor([[1., 2.],
        [3., 4.]])
b tensor([[5., 6.],
        [7., 8.]])
f tensor([[19., 22.],
        [43., 50.]])

最近在学习pytorch框架,可以加QQ3408649893交流。

©️2020 CSDN 皮肤主题: 精致技术 设计师:CSDN官方博客 返回首页