PyTorch的state_dict、parameters、modules、nn.Sequential及ModuleList源码学习(学习笔记)

可以进行原视频的观看,此为视频推荐:8、深入剖析PyTorch的state_dict、parameters、modules源码_哔哩哔哩_bilibili

 9、深入剖析PyTorch的nn.Sequential及ModuleList源码_哔哩哔哩_bilibili

以下代码是我的学习笔记,如有出错或者不懂,请指正,同时推荐可以观看以上视频来解惑↑↑↑

下面代码是state_dict、parameters、modules源码使用

import torch.nn


class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.linear1 = torch.nn.Linear(2,3)
        self.linear2 = torch.nn.Linear(3,4)
        self.batch_norm = torch.nn.BatchNorm2d(4)

test_module = Test()

# print(test_module._modules)  #返回字典 -->> OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])
# print(test_module._modules['linear1'].weight) #此时是float32

test_module.to(torch.double) #所有变为双精度
# print(test_module._modules['linear1'].weight)#此时是float64

test_module.to(torch.float32)
# print(test_module._parameters)
# print(test_module._buffers) #由于未调用,所以俩个均为空字典

# print(test_module.state_dict()) #打印出四个权重,以及batch_norm.weight,batch_norm.bias

# print(test_module.state_dict()['linear1.weight']) #用键可以找出相应值,得到对应张量值tensor([[ 0.4146, -0.6774],
                                                                                    # [ 0.6006,  0.3286],
                                                                                    # [ 0.6237,  0.6902]])
# for p in test_module.parameters():
#     print(p) #可以的到test_modules的所有参数

# 以下可能更加清楚
# print(len(test_module._parameters())) #有下划线返回当前,而不包含子module参数
# print(len(test_module.parameters())) #无下划线返回当前和子module参数,此时报错

# for p in test_module.named_parameters():
#     print(p) #可以的到键值对,当前张量在哪一个模块,什么权重,更加清晰

# for p in test_module.named_children():
#     print(p)  #返回元组,而._modules返回的是字典,此为俩者区别

# for p in test_module.named_modules():
#     print(p)  #返回四个模块,会返回自己这个模块,而_modules只会返回子模块(即3个)
#     print('\n')

#print(str(test_module))
#print(dir(test_module)) #可以查看参数,模块名称,buffer,类的名称

然后是self.training源码解析

以下是视频讲解链接:9、深入剖析PyTorch的nn.Sequential及ModuleList源码_哔哩哔哩_bilibili

以下是在pytorch中子类出现过的,链接如下:Dropout — PyTorch 2.0 documentation(dropout可以是网络在正向传播过程中随机失活一部分神经元)

torch.nn.modules.dropout — PyTorch 2.0 documentation

torch.nn.modules.batchnorm — PyTorch 2.0 documentation

class Dropout1d(_DropoutNd):
    def forward(self, input: Tensor) -> Tensor:
        return F.dropout1d(input, self.p, self.training, self.inplace)

进行推理模式,所有均设置为False

 torch.nn.Sequential讲解

讲解博客:torch.nn.Sequential和torch.nn.ModuleList_MoMona_W的博客-CSDN博客

import torch.nn

s = torch.nn.Sequential(torch.nn.Linear(2,3),(torch.nn.Linear(3,4)))
print(s) #返回0,1是因为没有使用字典方式,字典方式可查网上教程

torch.nn.ModuleList讲解:存放模型的列表

讲解博客:torch.nn.Sequential和torch.nn.ModuleList_MoMona_W的博客-CSDN博客

 MyModulelist

 Parameterlist

Parameterdict

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值