Pytorch学习-nn.functional和nn.Module

一,nn.functional 和 nn.Module

Pytorch和神经网络相关的功能组件大多都封装在 torch.nn模块下。

这些功能组件的绝大部分既有函数形式实现,也有类形式实现。

其中nn.functional(一般引入后改名为F)有各种功能组件的函数实现。例如:

(激活函数)

F.relu F.sigmoid F.tanh F.softmax (模型层)

F.linear F.conv2d F.max_pool2d F.dropout2d F.embedding (损失函数)

F.binary_cross_entropy F.mse_loss F.cross_entropy 为了便于对参数进行管理,一般通过继承 nn.Module 转换成为类的实现形式,并直接封装在 nn 模块下。例如:

(激活函数)

nn.ReLU nn.Sigmoid nn.Tanh nn.Softmax (模型层)

nn.Linear nn.Conv2d nn.MaxPool2d nn.Dropout2d nn.Embedding (损失函数)

nn.BCELoss nn.MSELoss nn.CrossEntropyLoss 实际上nn.Module除了可以管理其引用的各种参数,还可以管理其引用的子模块,功能十分强大。

二,使用nn.Module来管理参数

import torch
from torch import nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
# nn.Parameter 具有requires_grad = True 属性
w  = nn.Parameter(torch.randn(2,2))
print(w)
print(w.requires_grad)
Parameter containing:
tensor([[-1.5365,  0.0448],
        [-0.3267,  0.3325]], requires_grad=True)
True
#nn.parameterList 可以将多个nn.parameter组成一个列表
params_list = nn.ParameterList([nn.Parameter(torch.rand(8,i)) for i in range(1,3)])
print(params_list)
print(params_list[0])
print(params_list[0].requires_grad)
ParameterList(
    (0): Parameter containing: [torch.FloatTensor of size 8x1]
    (1): Parameter containing: [torch.FloatTensor of size 8x2]
)
Parameter containing:
tensor([[0.0708],
        [0.6856],
        [0.5501],
        [0.7626],
        [0.8094],
        [0.6685],
        [0.6888],
        [0.0835]], requires_grad=True)
True
#nn.ParameterDict 可以将nn.Parameter 组成一个字典
params_dict = nn.ParameterDict({
   "a":nn.Parameter(torch.rand(2,2)),"b":nn.Parameter(torch.zeros(2))})
print(params_dict)
print(params_dict["a"])
print(params_dict["a"].requires_grad)
ParameterDict(
    (a): Parameter containing: [torch.FloatTensor of size 2x2]
    (b): Parameter containing: [torch.FloatTensor of size 2]
)
Parameter containing:
tensor([[0.1196, 0.0379],
        [0.1089, 0.5946]], requires_grad=True)
True
#可以用Module将他们管理
#Module.parameter()返回一个生成器,包括其结构下的所有parameters
module = nn.Module()
module.w = w
module.params_list = params_list
module.params_dict = params_dict

num_param = 0
for param in module.parameters():
    print(param,'\n')<
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值