前言:目前再做医学图像分类算法的项目,我们之前用了一些基本的模型去跑,例如resnet,densenet,efficientnet等。现在我们需要使用transformer模型,但是在测试的时候出现了如下图的错误,显示某个模块中没有某个属性,其实这是个常见错误。
看到这个大家先不要慌,没有这个属性,说明是模型中没有匹配代码中的属性,这个时候要找到对应的定义model的类。我在这里测试了vit(vision transformer),deit(data efficient image transformer),st(swin-transformer)的基准型base,在测试这3个的时候都要修改对应的连接层或属性以适应该类别的任务。
原来的代码是这样的,就会出现上述类似错误
class SELFMODEL(nn.Module): def __init__(self, model_name=params['model'], out_features=params['num_classes'], pretrained=True): super().__init__() self.model = timm.create_model(model_name, pretrained=pretrained) # 从预训练的库中加载模型 # 根据模型名称修改分类器层 if model_name[:3] == "res": n_features = self.model.fc.in_features self.model.fc = nn.Linear(n_features, out_features) elif model_name[:3] == "vit": n_features = self.model.head.in_features self.model.head = nn.Linear(n_features, out_features) elif model_name[:4] == "swin": n_features = self.model.head.in_features self.model.head = nn.Linear(n_features, out_features) elif model_name[:4] == "deit": n_features = self.model.head.in_features self.model.head = nn.Linear(n_features, out_features) elif model_name[:8] == "convnext": n_features = self.model.head.in_features self.model.head = nn.Linear(n_features, out_features) else: raise ValueError("ConvNeXt model does not have the expected 'head.fc' structure") else: n_features = self.model.classifier.in_features self.model.classifier = nn.Linear(n_features, out_features) # 打印模型结构 print(self.model)
模型的头部(head)是一个 Sequential
容器,其中包含一个或多个层,最后的 Linear
层用于分类。我们需要正确地访问这个 Linear
层并替换它。
修改之后就可以使用了,代码如下图所示
class SELFMODEL(nn.Module):
def __init__(self, model_name=params['model'], out_features=params['num_classes'],
pretrained=True):
super().__init__()
self.model = timm.create_model(model_name, pretrained=pretrained) # 从预训练的库中加载模型
# 根据模型名称修改分类器层
if model_name[:3] == "res":
n_features = self.model.fc.in_features
self.model.fc = nn.Linear(n_features, out_features)
elif model_name[:3] == "vit":
n_features = self.model.head.in_features
self.model.head = nn.Linear(n_features, out_features)
elif model_name[:4] == "swin":
n_features = self.model.head.in_features
self.model.head = nn.Linear(n_features, out_features)
elif model_name[:4] == "deit":
n_features = self.model.head.in_features
self.model.head = nn.Linear(n_features, out_features)
elif model_name[:8] == "convnext":
# ConvNeXt 的 head 是一个 Sequential 容器,其中包含一个名为 fc 的 Linear 层
if hasattr(self.model, 'head') and hasattr(self.model.head, 'fc'):
n_features = self.model.head.fc.in_features
self.model.head.fc = nn.Linear(n_features, out_features)
else:
raise ValueError("ConvNeXt model does not have the expected 'head.fc' structure")
else:
n_features = self.model.classifier.in_features
self.model.classifier = nn.Linear(n_features, out_features)
# 打印模型结构
print(self.model)