1. 载入预训练模型
# Build the Swin UNETR model, move it to `device`, and initialise it from the
# self-supervised checkpoint via the model's own `load_from` method.
# NOTE(review): getSwinUnetr, patch_size, num_classes and device are assumed to
# be defined earlier in the notebook/script — confirm.
model = getSwinUnetr(patch_size, num_classes=num_classes, use_checkpoint=True)
model = model.to(device)
# map_location lets the checkpoint load even when the devices it was saved on
# (e.g. CUDA) are not available on the current machine.
weight = torch.load("./model_swinvit.pt", map_location=device)
model.load_from(weights=weight)
print("Using pretrained self-supervised Swin UNETR backbone weights !")
2. 打印参数名称
# Print the name of every entry in the model's state dict. In the sample output
# below, the final entries are the output head (out.conv.conv.weight / .bias).
for name in model.state_dict():
    print(name)
输出:
......
......
decoder2.conv_block.conv1.conv.weight
decoder2.conv_block.conv2.conv.weight
decoder2.conv_block.conv3.conv.weight
decoder1.transp_conv.conv.weight
decoder1.conv_block.conv1.conv.weight
decoder1.conv_block.conv2.conv.weight
decoder1.conv_block.conv3.conv.weight
out.conv.conv.weight
out.conv.conv.bias
3. 冻结除最后的 out 层以外的所有层
# Freeze every parameter except the output head: only parameters whose names
# start with "out." remain trainable. Selecting the head by name instead of by
# position (list(model.parameters())[:-2]) stays correct even if the head's
# parameter count or the registration order changes.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("out.")
4. 打印没有冻结的层
此时应该只有 out 层未被冻结
# Sanity check: print the names of all parameters that are still trainable
# (requires_grad=True). After the freezing step, only out.* should appear.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)