# Build a 10-class ViT and initialise it from the checkpoint file
# 'vit-in1k.pth', skipping the checkpoint's classification head.
model = ViT(num_classes=10)
ckpt = torch.load('vit-in1k.pth', map_location='cpu')

# Remove the head parameters BEFORE loading so they are simply not loaded.
# strict=False only tolerates missing/unexpected keys; a key that exists on
# both sides with a different shape still raises RuntimeError.
# NOTE(review): presumably the checkpoint head was trained for a different
# number of classes than the 10 used here -- confirm against the checkpoint.
# A default is passed to pop() so a checkpoint saved without a head is fine.
ckpt.pop('head.weight', None)
ckpt.pop('head.bias', None)

# msg reports which keys were missing/unexpected; the head keys removed
# above are expected to show up as missing.
msg = model.load_state_dict(ckpt, strict=False)
print(msg)