PyTorch 19. PyTorch中相似操作的区别与联系

view() 和 reshape()

写在开头:
有一篇大佬的总结非常到位:博客

总结

  1. view() 在操作tensor时,需要tensor是内存连续的,而且在进行尺寸变换时,view()操作不会新开辟内存空间。但是要保证tensor连续,对tensor进行tensor.contiguous()时,会开辟新的内存空间,存放内存连续的数据。
  2. reshape()操作,与view()的作用一模一样,但是它比view()更高级,被操作的tensor是内存连续时,直接采用reshape不会开辟新的内存;被操作的tensor不是内存连续时,reshape操作会开辟新的内存,再对tensor进行reshape。
  3. 最后,用reshape操作就完事了

expand()和repeat()

expand()

返回当前张量在某维扩展更大后的张量。扩展(expand)张量不会分配新的内存,只是在存在的张量上创建一个新的视图(view),一个大小等于1的维度扩展到更大的尺寸
例子:

import torch
x = torch.tensor([1, 2, 3])
x.expand(2,3)
tensor([[1, 2, 3],
			[1, 2, 3]])

注意 expand()只能扩展维度为1的维数,维数不为1的部分要保持一致

repeat()

沿着特定的维度重复这个张量,和expand()不同的是,这个函数拷贝张量的数据
例子

import torch

x = torch.tensor([1, 2, 3])
x.repeat(3,2)
tensor([[1, 2, 3, 1, 2, 3],
		[1, 2, 3, 1, 2, 3],
		[1, 2, 3, 1, 2, 3]])
x2 = torch.randn(2, 3, 4)
x2.repeat(2, 1, 3).shape

torch.Tensor([4, 3, 12])

乘法操作

pytorch中的乘法操作有:torch.mm(), torch.bmm(), torch.matmul(), torch.mul(), 运算符,以及torch.einsum()

二维矩阵乘法 torch.mm()

该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
torch.mm(mat1, mat2, out=None), 其中mat1为(nxm),mat2为(mxd),输出维度是(nxd)

三维带batch的矩阵乘法torch.bmm()

该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。

由于神经网络训练一般采用mini-batch,经常输入的三维带batch的矩阵,所以提供torch.bmm(bmat1, bmat2, out=None),其中bmat1为(bxnxm),bmat2为(bxmxd),输出out的维度是(bxnxd)

多维矩阵乘法 torch.matmul()

torch,matmul(input, other, out=None)支持broadcast操作
针对多维数据matmul()乘法,可以认为该matmul()乘法使用两个参数的后两个维度计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别为input->(100x500x99x11)other->(500x11x99)那么我们可以认为该乘法首先进行后两位矩阵乘法得到(99x11)x(11x99)->(99,99),然后分析两个参数的batch size分别为(1000x500)500,可以广播为(1000x500),因此最终输出的维度是(1000x500x99x99)

矩阵逐元素(Element-wise)乘法torch.mul()

函数torch.mul(mat1, other, out=None),其中other乘数可以是标量,也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast即可。

两个运算符@和*

  1. @:矩阵乘法,自动执行合适的矩阵乘法函数
  2. *:elemnet-wise乘法

register_parameter()和parameter()

  1. Parameter()

Parameter是Tensor, 即Tensor拥有的属性它都有,比如可以根据data来访问参数数值,用grad来访问参数梯度

# 随便定义一个网络
net = nn.Sequential(nn.Linear(4,3), nn.ReLU(), nn.Linear(3,1))
# list让它可以访问
weight_0 = list(net[0].parameters())[0]
print(weight_0.data)
print(weight_0.grad)
  1. register_parameter(name, parameters)
    向建立的网络module添加parameter
    最大的区别:parameter可以通过注册网络时候的name获取
    例子
class Example(nn.Module):
	def __init__(self):
		super(Example, self).__init__()
		self.W1_params = nn.Parameter(torch.rand(2,3))
		self.register_parameter('W2_params', nn.Parameter(torch.rand(2,3)))
	def forward(self, x):
		return x
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值