pytorch预训练模型

输入

data_dir输入设置为数据集的根目录
model_name:[resnet, alexnet, vgg, squeezenet, densenet, inception]
num_classes是数据集中的类数,batch_size是用于训练的批次大小,可以根据您计算机的能力进行调整,num_epochs是我们要运行的训练时期的数量,以及feature_extract是一个布尔值,它定义了我们是微调还是特征提取.如果feature_extract = False,则微调模型并更新所有模型参数。 如果feature_extract = True,则仅更新最后一层参数,其他参数保持固定。

# Top level data directory. Here we assume the format of the directory conforms
#   to the ImageFolder structure
data_dir = "./data/hymenoptera_data"

# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "squeezenet"

# Number of classes in the dataset
num_classes = 2

# Batch size for training (change depending on how much memory you have)
batch_size = 8

# Number of epochs to train for
num_epochs = 15

# Flag for feature extracting. When False, we finetune the whole model,
#   when True we only update the reshaped layer params
feature_extract = True

辅助函数

train_model函数处理给定模型的训练和验证。作为输入,它采用PyTorch模型,数据加载器字典,损失函数,优化器,要训练和验证的指定时期数以及当模型是Inception模型时的布尔标志。 is_inception标志用于适应Inception v3模型,因为该体系结构使用辅助输出,并且总体模型损失同时考虑了辅助输出和最终输出,如此处所述。 该函数针对指定的时期数进行训练,并且在每个时期之后运行完整的验证步骤。 它还跟踪最佳模型(在验证准确性方面),并且在训练结束时返回最佳模型。 在每个时期之后,将打印训练和验证准确性。

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

初始化

打印预训练网络结构,重新定义最后一层,如

#resnet
model.fc = nn.Linear(512, num_classes)
#Alexnet,VGG
model.classifier[6] = nn.Linear(4096,num_classes)
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0
    if model_name == "vgg":
        """ VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224
    else:
        print("Invalid model name, exiting...")
        exit()
    return model_ft, input_size
# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

# Print the model we just instantiated
print(model_ft)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值