def _set_module(model, submodule_key, module):
    """Replace the attribute at a dotted path below ``model`` with ``module``.

    Args:
        model: Root object to start the lookup from (the model).
        submodule_key: Dotted attribute path of the submodule to replace,
            relative to ``model`` (e.g. a name produced by
            ``named_modules()``).
        module: The new module stored at that path.

    Raises:
        ValueError: If the last path component is empty (``""`` or a key
            ending in ``"."``).
        AttributeError: If an intermediate path component does not exist
            (propagated from ``getattr``).
    """
    *parent_path, leaf = submodule_key.split('.')
    if not leaf:
        # setattr(obj, '', module) would silently register the replacement
        # under an empty name instead of replacing anything.
        raise ValueError(
            "submodule_key must end with a non-empty attribute name, "
            "got %r" % (submodule_key,)
        )

    parent = model
    for attr in parent_path:
        parent = getattr(parent, attr)
    setattr(parent, leaf, module)
参数如下:
model:待修改的模型(根模块)
submodule_key:要替换的子模块名称,即以“.”分隔的属性路径(例如 named_modules() 返回的 "layer1.0.conv1")
module:用来替换的新模块
示例, 把resnet18的所有Conv2d层的卷积核第一维尺寸设为1:
# Example: rebuild every Conv2d of resnet18 so that the first kernel
# dimension is 1 (stride 1 and no padding along that dimension).
# The replacement layers are freshly initialised; weights are not copied.
model = models.resnet18()
# Snapshot the traversal first: modules are replaced while we loop, so we
# must not keep walking the live generator.
for name, module in list(model.named_modules()):
    if not isinstance(module, Conv2d):
        continue

    kernel_size = (
        (1, module.kernel_size[1])
        if isinstance(module.kernel_size, tuple)
        else (1, module.kernel_size)
    )
    stride = (
        (1, module.stride[1])
        if isinstance(module.stride, tuple)
        else (1, module.stride)
    )
    dilation = (
        (1, module.dilation[1])
        if isinstance(module.dilation, tuple)
        else (1, module.dilation)
    )

    if isinstance(module.padding, str):
        # 'same' / 'valid' are not per-dimension values; pass them through.
        padding = module.padding
    elif isinstance(module.padding, tuple):
        padding = (0, module.padding[1])
    else:
        # A size-1 kernel needs no padding along the first dimension,
        # matching the tuple branch above.
        padding = (0, module.padding)

    # Carry over the remaining hyper-parameters so the new layer differs
    # from the old one only along the first kernel dimension (resnet18
    # convolutions, for instance, are created without bias).
    replacement = Conv2d(
        module.in_channels,
        module.out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=module.groups,
        bias=module.bias is not None,
        padding_mode=module.padding_mode,
    )
    _set_module(model, name, replacement)
参考链接:https://zhuanlan.zhihu.com/p/356273702,侵删