网上找到的都是几年前的了,更新一下最近的用法和踩的坑
模型加载
用torchvision.models.XXX加载模型时用参数pretrained=True好像可以运行,但是会报红,用weights没问题,下面举了两个例子
pretrained_net = torchvision.models.resnet18(weights = torchvision.models.ResNet18_Weights.DEFAULT)
pretrained_net = torchvision.models.vgg19(weights = torchvision.models.VGG19_Weights.DEFAULT)
截取部分层数
以vgg19为例子
net = nn.Sequential(*list(pretrained_net.children())[:-1])
返回的是features和avgpool两层,也就是.children返回的是模型的所有顶层模块
而用如下代码
net = nn.Sequential(*[pretrained_net.features[i] for i in range(20)])
则是直接用 顶层模块[ ]才能直接取出 features
容器中的前 20 层
参数修改没有研究过,后续需要再来补坑