Pytorch:修改模型的特定模块/层

def _set_module(model, submodule_key, module):
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)

参数如下
model:模型
submodule_key:模型名称
module:用来替换的新模块
示例, 把resnet18的所有Conv2d层的卷积核第一维尺寸设为1:

model = models.resnet18()
for name, module in model.named_modules():
    if isinstance(module, Conv2d):
        kernel_size = (1, module.kernel_size[1]) \
            if isinstance(module.kernel_size, tuple) \
            else (1, module.kernel_size)
        stride = (1, module.stride[1]) \
            if isinstance(module.stride, tuple) \
            else (1, module.stride)
        padding = (0, module.padding[1]) \
            if isinstance(module.padding, tuple) \
            else (1, module.padding)

        _set_module(model, name, Conv2d(module.in_channels, module.out_channels, kernel_size=kernel_size,
                                        stride=stride, padding=padding))

参考链接:https://zhuanlan.zhihu.com/p/356273702,侵删

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值