通过类定义一个网络

import torch
from torch import nn

x = torch.ones(2,10)

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.out = nn.Linear(10, 1)
    def forward(self,x):
        return self.out(x)

 1. 代码解析

  • 如何定义一个类?self 又是什么东西?
  • 类是如何继承基类的特性的?nn.module 是个什么对象?
  • 为什么会有一个初始化函数 init,初始化函数中的 super().init 函数是做什么用的?是否必须要有?
  • forward 函数有什么作用?该怎样用?
  • 为什么初始化函数前后有两条下划线?forward 前后为什么没有下划线?
  • nn.Linear 函数是干什么用的?
  • 输入输出张量的形状大小都是怎么对应的?
  • 模型内部的网络参数是怎么定义的,如何查看?

1.1 如何定义一个类 

class MLPSimple(nn.Module):
    def __init__(self):
        super().__init__()
        self.out = nn.Linear(10, 1)

a_s = MLPSimple()
a_s

 self 是指向类自身的一个指针,可以通过该指针引用类自身的成员,默认这个参数是每个成员函数的首个输入参数,如果没有self参数,那么定义的函数将无法引用类自身的成员

1.2 如何继承基类的特性?nn.module 是个什么对象?

定义类的时候,将需要被继承的类的名称作为参数传入,如 class MLPSimple(nn.Module) 这样就是定义了一个新的类 MLPSimple ,这个新类继承了 nn.Module 的所有特性。 以下展示再创建一个类,继承刚刚创建的新类 MLPSimple_p。

class MLPSimple_p(MLPSimple):
    def __init__(self):
        super().__init__()
a_s_p = MLPSimple_p()
a_s_p

nn.Module是PyTorch中的一个类,继承自torch.nn的基类,用于定义神经网络模型、提供前向传播过程所需的基本功能和方法。 在PyTorch中,神经网络模型通常是由多个层组成的,每个层都是一个nn.Module实例。通过继承nn.Module类并实现自己的forward方法,可以定义自己的神经网络模,。在神经网络的训练和推理过程中,PyTorch会自动调用nn.Module的forward方法来计算输出。

1.3 为什么会有一个初始化函数 init,初始化函数中的 super().init 函数是做什么用的?是否必须要有? 

与C++不同,我们自己新定义的python类没有显式的构造函数(python 类有自己的构造函数,该构造函数跟 init 函数一样也是个魔法函数),python类的对init函数的调用,可以被看做是类似于C++类调用构造函数类似的过程,当python中通过类创建对象的时候就会调用init函数对对象进初始化,与c++不同的是如果c++继承了基类,那么构造对象的时候会隐式的调用基类的构造方法,这里python却需要显示的主动调用基类初始化方法super().init()对基类的特性进行初始化。这里的显示调用时必须的,如果漏掉会报错。

1.4 forward 函数有什么作用?该怎样用?

python 的类成员方法中有一个非常特殊的函数叫做 call() 函数,这个函数使得实例化的对象自身可以像一个函数一样被调用,如同样实例对象为 a_s_p ,如果这个对象是在C++中,那么这个对象就单纯的是一个对象而已,要想让这个对象处理一些事情,就必须通过对象去调用它自身的一些方法来实现,如a_s_p.func(),但是在python类中,类定义里面有一个特别的函数叫做call函数,这个函数可以使被实例的对象本身像一个函数一样被直接调用,而在pytorch中这个call函数会默认直接调用创建类的forward函数,forward函数会接受所有传递给call函数的参数,call函数本身也会将forward函数的返回结果直接返回,因此就形成了pytorch中这种可以直接通过对象本身来处理数据的现象。

a_s_p(inputs) 隐含的意思就是 a_s_p.call(inputs) , 而 a_s_p.call(inputs) 本身的定义却类似于以下这样:

class MLPSimple(nn.Module):
    def __call__(self,inputs):
        return self.forward(inputs)
    def forward(self,inputs)
        return outputs

1.5 为什么初始化函数前后有两条下划线?forward 前后为什么没有下划线?

函数前后有两条下划线的方法叫做python的魔法函数,魔法函数本身是指的到了特定状况下会自动被调用的函数,因为其自适应性像魔法一样神奇所以被称为魔法函数,没有下划线的函数指的是普通函数,像forward函数的名称是pytorch的保留字,默认被call函数调用,但它仍然跟普通函数一样,没什么特别之处。其他的魔法函数还有如下这些:

  • init():类的初始化方法,在创建类的实例时自动调用。
  • new():类的构造函数,当使用类的构造函数创建新的类实例时自动调用。
  • str():返回对象的字符串表示,当调用print()函数输出对象时自动调用。
  • del():在对象被删除时自动调用。
  • call():当对象被作为函数调用时自动调用。
  • len():返回对象的长度,当使用len()函数调用对象时自动调用。
  • eq():比较两个对象是否相等,当使用==运算符比较两个对象时自动调用。
  • hash():返回对象的哈希值,当使用hash()函数调用对象时自动调用。
  • getitem():当使用方括号运算符[]访问对象的元素时自动调用。
  • setitem():当使用方括号运算符[]修改对象的元素时自动调用。的元素时自动调用。[]修改对的元素时自动调用。

1.6 nn.Linear 函数是干什么用的?

对输入向量进行线性变换的一个网络层类,,与之类似的类还有以下几个类: 'Bilinear(双线性变换' 'Identity(占位符) 'LazyLinear'(系数矩阵尺寸在第一次被调用时候自动初始化,不需要主动指定)

class Linear(Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,device=None, dtype=None) -> None:
        ...

   def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)

详细介绍见官方文档:Linear — PyTorch 2.0 documentation

初始化时候指定input_tensor[size_in_0,in_features] \output_tensor[size_out_0,out_features],即指定了线性层的系数矩阵尺寸 weight[in_features,out_features],计算时候 out_tensor = input_tensor * weight = [size_in_0,in_features] * [in_features,out_features] = [size_in_0 ,out_features] 在开头的例子中即为 out_tensor = input_tensor * weight = [2,10] * [10,1] = [2 ,1] 以上计算过程中可以发现,矩阵的最后一个维度是样本的特征维度,比如说线性变换中的自变量个数即为 in_features = 10 , 因变量的个数为 out_features = 1 ,这两个个数即为单个样本的特征维度或者说是特征数。倒数第二个维度是样本的批量大小,像本例中输入样本为2,输出样本自然的对应也应该是2,输入样本的数量不需要单独指定,在传入模型处理的时候,模型会自动去识别处理。

1.7 模型内部的网络参数是怎么定义的,如何查看?

访问权重系数 :print(a.out.weight)

访问偏置系数 :print(a.out.bias)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值