torch.norm()函数的用法

目录

一、函数定义

二、代码示例

三、整体代码


一、函数定义

公式:

                                                                     ||x||_{p} = \sqrt[p]{x_{1}^{p} + x_{2}^{p} + \ldots + x_{N}^{p}}

意思就是inputs的一共N维的话对这N个数据p范数,当然这个还是太抽象了,接下来还是看具体的代码~

p指的是求p范数的p值,函数默认p=2,那么就是求2范数

    def norm(self, input, p=2): # real signature unknown; restored from __doc__
        """
        .. function:: norm(input, p=2) -> Tensor
        
        Returns the p-norm of the :attr:`input` tensor.
        
        .. math::
            ||x||_{p} = \sqrt[p]{x_{1}^{p} + x_{2}^{p} + \ldots + x_{N}^{p}}
        
        Args:
            input (Tensor): the input tensor
            p (float, optional): the exponent value in the norm formulation
        Example::
        
            >>> a = torch.randn(1, 3)
            >>> a
            tensor([[-0.5192, -1.0782, -1.0448]])
            >>> torch.norm(a, 3)
            tensor(1.3633)
        
        .. function:: norm(input, p, dim, keepdim=False, out=None) -> Tensor
        
        Returns the p-norm of each row of the :attr:`input` tensor in the given
        dimension :attr:`dim`.
        
        If :attr:`keepdim` is ``True``, the output tensor is of the same size as
        :attr:`input` except in the dimension :attr:`dim` where it is of size 1.
        Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting
        in the output tensor having 1 fewer dimension than :attr:`input`.
        
        Args:
            input (Tensor): the input tensor
            p (float):  the exponent value in the norm formulation
            dim (int): the dimension to reduce
            keepdim (bool): whether the output tensor has :attr:`dim` retained or not
            out (Tensor, optional): the output tensor
        
        Example::
        
            >>> a = torch.randn(4, 2)
            >>> a
            tensor([[ 2.1983,  0.4141],
                    [ 0.8734,  1.9710],
                    [-0.7778,  0.7938],
                    [-0.1342,  0.7347]])
            >>> torch.norm(a, 2, 1)
            tensor([ 2.2369,  2.1558,  1.1113,  0.7469])
            >>> torch.norm(a, 0, 1, True)
            tensor([[ 2.],
                    [ 2.],
                    [ 2.],
                    [ 2.]])
        """
        pass

二、代码示例

输入代码

import torch

rectangle_height = 3
rectangle_width = 4
inputs = torch.randn(rectangle_height, rectangle_width)
for i in range(rectangle_height):
    for j in range(rectangle_width):
        inputs[i][j] = (i + 1) * (j + 1)
print(inputs)

得到一个3×4矩阵,如下

tensor([[ 1.,  2.,  3.,  4.],
        [ 2.,  4.,  6.,  8.],
        [ 3.,  6.,  9., 12.]])

接着我们分别对其分别求2范数

inputs1 = torch.norm(inputs, p=2, dim=1, keepdim=True)
print(inputs1)
inputs2 = torch.norm(inputs, p=2, dim=0, keepdim=True)
print(inputs2)

结果分别为

tensor([[ 5.4772],
        [10.9545],
        [16.4317]])
tensor([[ 3.7417,  7.4833, 11.2250, 14.9666]])

怎么来的?

inputs1:(p = 2,dim = 1)每行每一列数据进行2范数运算

5.4772 = \sqrt[2]{1^{2} + 2^{2} + 3^{2} + 4^{2}}

10.9545 = \sqrt[2]{2^{2} + 4^{2} + 6^{2} + 8^{2}}

15.4317 = \sqrt[2]{3^{2} + 6^{2} + 9^{2} + 12^{2}}

inputs2:(p = 2,dim = 0)每列每一行数据进行2范数运算

3.7417= \sqrt[2]{1^{2} + 2^{2} + 3^{2}}

7.4833 = \sqrt[2]{2^{2} + 4^{2} + 6^{2}}

11.2250 = \sqrt[2]{3^{2} + 6^{2} +9^{2}}

14.9666 = \sqrt[2]{4^{2} + 8^{2} + 12^{2}}


关注keepdim = False这个参数

inputs3 = inputs.norm(p=2, dim=1, keepdim=False)
print(inputs3)

inputs3

tensor([ 5.4772, 10.9545, 16.4317])

输出inputs1inputs3shape

print(inputs1.shape)
print(inputs3.shape)
torch.Size([3, 1])
torch.Size([3])

可以看到inputs3少了一维,其实就是dim=1(求范数)那一维(列)少了,因为从4列变成1列,就是3行中求每一行的2范数,就剩1列了,不保持这一维不会对数据产生影响

或者也可以这么理解,就是数据每个数据有没有用[]扩起来。

keepdim = True[]扩起来;

keepdim = False不用[]括起来~;


不写keepdim,则默认不保留dim的那个维度

inputs4 = torch.norm(inputs, p=2, dim=1)
print(inputs4)
tensor([ 5.4772, 10.9545, 16.4317])

不写dim,则计算Tensor中所有元素的2范数

inputs5 = torch.norm(inputs, p=2)
print(inputs5)
tensor(20.4939)

等价于这句话

inputs6 = inputs.pow(2).sum().sqrt()
print(inputs6)
tensor(20.4939)

20.4939 = \sqrt[2]{1^{2} + 2^{2} + 3^{2} + 4^{2} + 2^{2} + 4^{2} + 6^{2} + 8^{2} + 3^{2} + 6^{2} + 9^{2} + 12^{2}}


总之,norm操作后dim这一维变为1或者消失


三、整体代码

"""
@author:nickhuang1996
"""

import torch

rectangle_height = 3
rectangle_width = 4
inputs = torch.randn(rectangle_height, rectangle_width)
for i in range(rectangle_height):
    for j in range(rectangle_width):
        inputs[i][j] = (i + 1) * (j + 1)
print(inputs)

inputs1 = torch.norm(inputs, p=2, dim=1, keepdim=True)
print(inputs1)
inputs2 = torch.norm(inputs, p=2, dim=0, keepdim=True)
print(inputs2)

inputs3 = inputs.norm(p=2, dim=1, keepdim=False)
print(inputs3)

print(inputs1.shape)
print(inputs3.shape)


inputs4 = torch.norm(inputs, p=2, dim=1)
print(inputs4)

inputs5 = torch.norm(inputs, p=2)
print(inputs5)
inputs6 = inputs.pow(2).sum().sqrt()
print(inputs6)
发布了127 篇原创文章 · 获赞 979 · 访问量 142万+

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 技术黑板 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览