OrderedDict,如何修改pytorch网络模型中的最后一层

Python中的OrderedDict是一种字典类型,它记住了元素插入的顺序。在Python 3.7及以后的版本中,普通的dict类型已经保证了插入顺序,但OrderedDict仍然有其独特的用途,比如它提供了一些额外的方法,如popitem(last=True),这在普通字典中是没有的。

以下是OrderedDict的一些基本用法:

  1. 创建OrderedDict: 创建一个空的OrderedDict或从一个序列的键值对中创建。
    from collections import OrderedDict
    
    # 创建一个空的OrderedDict
    ordered_dict = OrderedDict()
    
    # 或者从一个序列的键值对中创建
    ordered_dict = OrderedDict([('key1', 'value1'), ('key2', 'value2')])

    添加元素: 使用__setitem__方法或update方法添加元素。

  2. # 使用 __setitem__ 方法添加元素
    ordered_dict['key3'] = 'value3'
    
    # 使用 update 方法添加元素
    ordered_dict.update({'key4': 'value4'})

    访问元素: 使用键来访问元素。

  3. value = ordered_dict['key1']

修改官方VGG16模型最后一层的方法:


model.classifier._modules['6'] = nn.Linear(4096,len(classNames))   #修改vgg-16模型最后一层全连接层,输出目标类别个数

classifier就是最后这个线性层所在的地方,长这样:

       self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )

它相当于一个字典式的存储,用键值'6',可以取出对应‘值’。

而对于VGGFace模型来说。他的分类层是这样定义的:

self.fc = nn.ModuleDict(OrderedDict(
            {
                'fc6': nn.Linear(in_features=512 * 7 * 7, out_features=4096),
                'fc6-relu': nn.ReLU(inplace=True),
                'fc6-dropout': nn.Dropout(p=0.5),
                'fc7': nn.Linear(in_features=4096, out_features=4096),
                'fc7-relu': nn.ReLU(inplace=True),
                'fc7-dropout': nn.Dropout(p=0.5),
                'fc8': nn.Linear(in_features=4096, out_features=2622),
            }))

orderdict是一个保证了顺序的字典,

在Python中,特别是在使用PyTorch深度学习框架时,nn.ModuleDict是一个特殊的容器,用于存储多个nn.Module对象。nn.Module是PyTorch中所有神经网络模块的基类,而ModuleDict则是一个继承自collections.abc.MutableMapping的类,它允许你以字典的方式存储和访问这些模块。

nn.ModuleDict的主要特点包括:

  1. 字典接口:它提供了与普通字典相同的接口,可以像操作字典一样操作ModuleDict

  2. 模块存储ModuleDict中的每个键值对的值必须是nn.Module的实例。

  3. 键的自动注册:当你将一个nn.Module添加到ModuleDict时,它会自动注册到nn.Module_modules属性中,这意味着这个模块的参数会被自动收集并用于后续的训练和推理。

  4. 方便的访问:可以通过键名直接访问ModuleDict中的模块。

这里要修改VGGFace的最后一层得用

model.fc._modules['fc8'] = nn.Linear(4096, len(classNames))  # 修改vgg-16模型最后一层全连接层,输出目标类别个数

fc是最后分类层所在的‘这块的名字’,_modules必须要带上。'fc8'就是最后这一层对应的键值,能取出来最后一层

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值