pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。经常用于初始化init_weights的操作。如下apply递归调用_init_vit_weights,初始化ViT模型的子模块。
from torch import nn
#Weight init,初始化pos_embed
# trunc_normal_利用正态分布生成一个点,点在[a, b]区间之内
nn.init.trunc_normal_(self.pos_embed, std=0.02)
# Weight init,初始化cls_token
nn.init.trunc_normal_(self.cls_token, std=0.02)
# 调用vit初始函数
self.apply(_init_vit_weights)
def _init_vit_weights(m):
"""
ViT weight initialization
:param m: module
"""
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)