【翻译】class torch.nn.ModuleDict(modules=None)

参考链接: class torch.nn.ModuleDict(modules=None)

说明:在PyTorch的1.2.0版本上这个方法有bug,用一个ModuleDict对象调用update()来更新另一个ModuleDict会报错,但在PyTorch的1.7.1版本上可以正常使用.

在这里插入图片描述

在这里插入图片描述

原文及翻译:

ModuleDict  ModuleDict章节

class torch.nn.ModuleDict(modules=None)
类型 class torch.nn.ModuleDict(modules=None)
    Holds submodules in a dictionary.
    该类能够以字典的方式持有子模块.
    ModuleDict can be indexed like a regular Python dictionary, but modules it contains 
    are properly registered, and will be visible by all Module methods.
    ModuleDict 类型能够像普通Python字典一样被索引访问,但是它和普通Python字典不同的是,该类型所
    包含的模块会被正确地注册登记,并且这些模块能被所有地Module模块方法可见.
    ModuleDict is an ordered dictionary that respects
    ModuleDict 类型是一个有序字典,它遵循:
        the order of insertion, and
        插入地先后顺序,并且
        in update(), the order of the merged OrderedDict or another ModuleDict 
        (the argument to update()).
		在方法update(),遵循被合并的有序字典OrderedDict的顺序或者
		另一个ModuleDict(,传递给方法update()的参数)的顺序.
    Note that update() with other unordered mapping types (e.g., Python’s plain dict) does 
    not preserve the order of the merged mapping.
    值得注意的是,在这个update()方法中如果传递了一个无序的映射类型(比如,Python的普通字典),那么不会
    保持被合并的这个映射类型的顺序.

    Parameters  参数
        modules (iterable, optional) – a mapping (dictionary) of (string: module) or 
        an iterable of key-value pairs of type (string, module)
		modules (iterable可迭代类型, 可选) – 一个映射(字符串:模块)类型(字典)或者
		一个键值对(字符串,模块)类型的可迭代对象.


    Example:  例子:

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.choices = nn.ModuleDict({
                    'conv': nn.Conv2d(10, 10, 3),
                    'pool': nn.MaxPool2d(3)
            })
            self.activations = nn.ModuleDict([
                    ['lrelu', nn.LeakyReLU()],
                    ['prelu', nn.PReLU()]
            ])

        def forward(self, x, choice, act):
            x = self.choices[choice](x)
            x = self.activations[act](x)
            return x

    clear()
    方法: clear()
        Remove all items from the ModuleDict.
        移除ModuleDict中的所有项目.

    items()
    方法: items()
        Return an iterable of the ModuleDict key/value pairs.
        返回一个ModuleDict的键/值对的可迭代对象.

    keys()
    方法: keys()
        Return an iterable of the ModuleDict keys.
        返回ModuleDict关键字的可迭代对象.

    pop(key)
    方法: pop(key)
        Remove key from the ModuleDict and return its module.
        在ModuleDict中移除关键字key.并且返回这个关键字对应的模块.
        Parameters  参数
            key (string) – key to pop from the ModuleDict
            
    update(modules)
    方法: update(modules)
        Update the ModuleDict with the key-value pairs from a mapping or an iterable, 
        overwriting existing keys.
        用一个键值对的映射类型或者可迭代对象来更新ModuleDict,覆写已经存在的关键字.
        Note  注意:
        If modules is an OrderedDict, a ModuleDict, or an iterable of key-value pairs, 
        the order of new elements in it is preserved.
        如果modules参数是一个有序字典OrderedDict或者ModuleDict或者键值对的可迭代对象,
        那么新元素的顺序也同样被维持.
        Parameters  参数
            modules (iterable) – a mapping (dictionary) from string to Module, or an 
            iterable of key-value pairs of type (string, Module)
            modules (iterable可迭代对象) – 字符串映射到Module模块的映射类型(字典),或者
            是键值对(字符串,Module模块)类型的可迭代对象.

    values()
    方法: values()
        Return an iterable of the ModuleDict values.
        返回一个ModuleDict键值对中的值的可迭代对象.

代码实验展示:

Microsoft Windows [版本 10.0.18363.1316]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0

(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>>
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x0000021D828AD330>
>>> import torch.nn as nn
>>> layers_1 = nn.ModuleDict({
...                 'conv': nn.Conv2d(10, 10, 3),
...                 'pool': nn.MaxPool2d(3)
...         })
>>> layers_1
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
>>> layers_1['conv']
Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
>>> layers_1['pool']
MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
>>>
>>> for item in layers_1.items():
...     print(item)
...
('conv', Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)))
('pool', MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False))
>>>
>>> for key in layers_1.keys():
...     print(key)
...
conv
pool
>>> for value in layers_1.values():
...     print(value)
...
Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
>>> layers_1.pop()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: pop() missing 1 required positional argument: 'key'
>>> layers_1.pop('conv')
Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
>>> layers_1
ModuleDict(
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
>>> layers_1.pop('pool')
MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
>>> layers_1
ModuleDict()
>>>
>>> layers_1 = nn.ModuleDict({
...                 'conv': nn.Conv2d(10, 10, 3),
...                 'pool': nn.MaxPool2d(3)
...         })
>>> layers_1
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
>>> layers_1.clear()
>>> layers_1
ModuleDict()
>>> layers_1 = nn.ModuleDict({
...                 'conv': nn.Conv2d(10, 10, 3),
...                 'pool': nn.MaxPool2d(3)
...         })
>>>
>>> layers_2 = nn.ModuleDict({
...             'conv': nn.Conv2d(5, 5, 5),
...             'pool2': nn.MaxPool2d(7)
...     })
>>>

代码实验:

Microsoft Windows [版本 10.0.18363.1316]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102

(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch.nn as nn
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000002C864EB7870>
>>>
>>> print(torch.__version__)
1.7.1
>>>
>>> layers_1 = nn.ModuleDict({
...     'conv': nn.Conv2d(10, 10, 3),
...     'pool': nn.MaxPool2d(3)
... })
>>>
>>>
>>> layers_2 = nn.ModuleDict([
...     ['lrelu', nn.LeakyReLU()],
...     ['prelu', nn.PReLU()]
... ])
>>>
>>> layers_2
ModuleDict(
  (lrelu): LeakyReLU(negative_slope=0.01)
  (prelu): PReLU(num_parameters=1)
)
>>> layers_1
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
>>> layers_1.update(layers_2)
>>> layers_1
ModuleDict(
  (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (lrelu): LeakyReLU(negative_slope=0.01)
  (prelu): PReLU(num_parameters=1)
)
>>> layers_2
ModuleDict(
  (lrelu): LeakyReLU(negative_slope=0.01)
  (prelu): PReLU(num_parameters=1)
)
>>>
>>>
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
torch.nn.Linear是PyTorch中的一个模块,用于定义一个线性层。它接受两个参数,即输入和输出的维度。通过调用torch.nn.Linear(input_dim, output_dim),可以创建一个线性层,其中input_dim是输入的维度,output_dim是输出的维度。Linear模块的主要功能是执行线性变换,将输入数据乘以权重矩阵,并加上偏置向量。这个函数的具体实现可以参考PyTorch官方文档中的链接。 在引用中的示例中,linear1是一个Linear模块的实例。可以通过print(linear1.weight.data)来查看linear1的权重。示例中给出了权重的具体数值。 在引用中的示例中,x是一个Linear模块的实例,输入维度为5,输出维度为2。通过调用x(data)来计算线性变换的结果。在这个示例中,输入data的维度是(5,5),输出的维度是(5,2)。可以使用torch.nn.functional.linear函数来实现与torch.nn.Linear相同的功能,其中weight和bias分别表示权重矩阵和偏置向量。 以上是关于torch.nn.Linear的一些介绍和示例。如果需要更详细的信息,可以参考PyTorch官方文档中关于torch.nn.Linear的说明。 https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [torch.nn.Linear详解](https://blog.csdn.net/sazass/article/details/123568203)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 33.333333333333336%"] - *2* [torch.nn.Linear](https://blog.csdn.net/weixin_41620490/article/details/127833324)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 33.333333333333336%"] - *3* [pytorch 笔记:torch.nn.Linear() VS torch.nn.function.linear()](https://blog.csdn.net/qq_40206371/article/details/124473437)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 33.333333333333336%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值