输入图片格式为:(b,c,224,224)
修改输出层的输出维度:
import torchvision
resnet_model = torchvision.models.resnet18(pretrained=True)
for param in resnet_model.parameters():
param.requires_grad = False
resnet_model.fc
class Net(nn.Module):
def __init__(self, model):
super(Net, self).__init__()
# 取掉model的后1层
self.resnet_layer = nn.Sequential(*list(model.children())[:-1])
self.Linear_layer = nn.Linear(512, 11) #加上一层参数修改好的全连接层,例如修改为11层
def forward(self, x):
x = self.resnet_layer(x)
x = x.view(x.size(0), -1)
x = self.Linear_layer(x)
return x
resnet_model = Net(resnet_model)
resnet18的模型使用
最新推荐文章于 2024-02-03 11:33:09 发布