【Pytorch】矩阵操作api介绍

本文介绍了torch库中mul和mm函数的区别,强调了它们在处理矩阵乘法时的不同条件,并讨论了dim=-1的含义。同时,通过实例展示了torch.cat()的使用,包括按行和按列拼接张量以及常见应用场景,如图像处理中的维度操作。
摘要由CSDN通过智能技术生成

torch.mul() 和 torch.mm() 区别

  • torch.mul(a, b)是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵。
  • torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵

dim为-1的解释

简单来说就是最后一维。

https://blog.csdn.net/weixin_39331401/article/details/109357343?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_title-0&spm=1001.2101.3001.4242

torch.cat()函数 

    import torch
    A=torch.ones(2,3) #2x3的张量(矩阵)                                     
    A
    tensor([[ 1.,  1.,  1.],
            [ 1.,  1.,  1.]])
    B=2*torch.ones(4,3)#4x3的张量(矩阵)                                    
    B
    tensor([[ 2.,  2.,  2.],
            [ 2.,  2.,  2.],
            [ 2.,  2.,  2.],
            [ 2.,  2.,  2.]])
    C=torch.cat((A,B),0)#按维数0(行)拼接
    C
    tensor([[ 1.,  1.,  1.],
            [ 1.,  1.,  1.],
            [ 2.,  2.,  2.],
            [ 2.,  2.,  2.],
            [ 2.,  2.,  2.],
            [ 2.,  2.,  2.]])
    C.size()
    torch.Size([6, 3])
    D=2*torch.ones(2,4) #2x4的张量(矩阵)
    C=torch.cat((A,D),1)#按维数1(列)拼接
    C
    tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],
            [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])
    C.size()
    torch.Size([2, 7])
    ```

上面给出了两个张量A和B,分别是2行3列,4行3列。即他们都是2维张量。因为只有两维,这样在用torch.cat拼接的时候就有两种拼接方式:按行拼接和按列拼接。即所谓的维数0和维数1. 

C=torch.cat((A,B),0)就表示按维数0(行)拼接A和B,也就是竖着拼接,A上B下。此时需要注意:列数必须一致,即维数1数值要相同,这里都是3列,方能列对齐。拼接后的C的第0维是两个维数0数值和,即2+4=6.

C=torch.cat((A,B),1)就表示按维数1(列)拼接A和B,也就是横着拼接,A左B右。此时需要注意:行数必须一致,即维数0数值要相同,这里都是2行,方能行对齐。拼接后的C的第1维是两个维数1数值和,即3+4=7.

从2维例子可以看出,使用torch.cat((A,B),dim)时,除拼接维数dim数值可不同外其余维数数值需相同,方能对齐。

# 实例

在深度学习处理图像时,常用的有3通道的RGB彩色图像及单通道的灰度图。张量size为cxhxw,即通道数x图像高度x图像宽度。在用torch.cat拼接两张图像时一般要求图像大小一致而通道数可不一致,即h和w同,c可不同。当然实际有3种拼接方式,另两种好像不常见。比如经典网络结构:U-Net,里面用到4次torch.cat,其中copy and crop操作就是通过torch.cat来实现的。可以看到通过上采样(up-conv 2x2)将原始图像h和w变为原来2倍,再和左边直接copy过来的同样h,w的图像拼接。这样做,可以有效利用原始结构信息。

# 总结

使用torch.cat((A,B),dim)时,除拼接维数dim数值可不同外其余维数数值需相同,方能对齐。
 

使用torch.cat()和torch.stack()时出现的小错误

# 错误的代码
a = torch.rand(1,128,128,dtype=float)
b = torch.rand(1,128,128,dtype=float)
c = torch.cat(a,b, dim=0)
print(c.size())

# 改为正确的如下,
a = torch.rand(1,128,128,dtype=float)
b = torch.rand(1,128,128,dtype=float)
c = torch.cat((a,b), dim=0)
print(c.size())

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值