nn.Module详解

#最有用的两个功能
#1、查看模型结构:
print(mdoel)

#2、查看结构的模型参数
print(model.named_parameters())
print(model.subModel.conv1[0].weight.size)

# 详解nn.Module
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchsummary import summary

class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.linear1 = nn.Linear(20,40)
        self.conv1 = nn.Conv2d(1,20,5)
        self.conv2 = nn.Conv2d(20,20,5)


    def forward(self,x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
    
# add_module:往当前模块中添加子模块,所以主模块可以通过点号访问子模块,最后会加在self._modules[name] = module

# apply:可以访问所有子模块,通过.children()方法,对所有子模块的参数进行初始化
# eg:@torch.n-_grad()
# def init_weights(m):
#     print(m)
#     if type(m) == nn.Linear:
#         m.weight.fill_(1.0)

#         print(m.weight)
# net = nn.Sequential(nn.Linear(2,2),nn.Linear(2,2))
# # 遍历net中所有的子模块,将子模块传入该函数,然后对子模块的参数赋值
# net.apply(init_weights)

# buffers
# register_parameter:当前模块中添加参数,最后会self._parameters[name] = param

# Parameter:作为模块的参数,不要写tensor类型

# children:返回所有的子模块

# cpu/cuda:将所有的模型和参数都搬到相应设备上.通过apply,将所有的模块参数传入

# eval:将模型设置为评估模式,会遍历所有子模块设为false。影响DropOut,BatchNorm上有区别

# get_parameter:根据字符串得到一个模型里面的参数,先解析字符串得到module_path

# get_submodule :根据字符串得到子模块

# load_state_dict:加载模型的所有参数和buffer
# state_dict:每个模块都有参数和buffer变量
# _save_to_state_dict:保存当前模块的参数和buffer
# _load_from_state_dict:读取当前模块参数和buffer

# named_parameters():模型中所有的参数和值
# for name, param in self.named_parameters():
#     print(name,param)

# requies_grad:参数是否需要更新

# type:将所有的模型参数转化数据类型

# to_empty 

# to(torch.double):所有的浮点数据类型都转换成float64

# zero_grad:清除梯度

# __repr__:魔法方法 str(model)

# __dir__:dir(model)当前模块的名称和参数


if __name__=='__main__':
    model = Model()
    model.eval()
    print(summary(model,(1,28,28)))
    # print(type(A.children()))
    for name, param in model.named_parameters():
        print(name,param.size())

    # 打印所有子模块名字和特征参数,返回的是有序字典
    model._modules()
    # 返回自身和子模块,跟_modules()相比,多了一个自身模块
    for p in model.named_modules():
        print(p)
    # 打印'linear1'层的特征参数
    model._modules['linear1']
    # 打印 linear1里面的weight参数
    model._modules['linear1'].weight
    # 打印数据类型
    model._modules['linear1'].weight.dtype

    # 返回空字典,当前模块的参数,而不是遍历子模块里面的参数
    model._parameters
    # 包含当前模块参数和子模块的参数
    for p in model.parameters:
        print(p)
    # 字典,包含参数名字和参数值
    for name, param in model.named_parameters():
        print(name,param.size())
    model._buffers

    # 子模块的名称和子模块本身,与_modules只是返回的格式不一样,其他无区别
    for p in model.named_children():
        print(p)



    # 所有的参数名字和参数
    model.state_dict()


    # 只能保存模型的参数,不能保存优化器参数
    torch.save(model.state_dict(),'model_weight.pth')
    # 加载模型的参数
    model.load_state_dict(torch.load('model_weights.pth'))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值