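Inputs
data_dir is set to the root directory of the dataset.
model_name is the name of the model to use, chosen from: [resnet, alexnet, vgg, squeezenet, densenet, inception]
num_classes is the number of classes in the dataset. batch_size is the batch size used for training and can be adjusted to suit your machine's memory. num_epochs is the number of training epochs to run. feature_extract is a boolean that determines whether we fine-tune or feature-extract: if feature_extract = False, the model is fine-tuned and all model parameters are updated; if feature_extract = True, only the parameters of the last layer are updated and all other parameters stay fixed.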
# Top level data directory. Here we assume the format of the directory conforms
# to the ImageFolder structure
data_dir = "./data/hymenoptera_data"
# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
# (initialize_model as shown in this post only implements the vgg branch)
model_name = "vgg"
# Number of classes in the dataset
num_classes = 2
# Batch size for training (change depending on how much memory you have)
batch_size = 8
# Number of epochs to train for
num_epochs = 15
# Flag for feature extracting. When False, we finetune the whole model,
# when True we only update the reshaped layer params
feature_extract = True
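Helper Functions
The train_model function handles the training and validation of a given model. As input, it takes a PyTorch model, a dictionary of dataloaders, a loss function, an optimizer, the number of epochs to train and validate for, and a boolean flag that indicates whether the model is an Inception model. The is_inception flag accommodates the Inception v3 model, because that architecture uses an auxiliary output and the overall model loss takes both the auxiliary output and the final output into account. The function trains for the specified number of epochs and runs a full validation step after each epoch. It also keeps track of the best-performing model (in terms of validation accuracy) and returns that model at the end of training. The training and validation accuracies are printed after each epoch.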
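The behaviour described above can be sketched roughly as follows. This is a minimal illustration, not the post's own implementation: the 0.4 weight on the auxiliary loss follows the convention used in the official PyTorch fine-tuning tutorial, the device is inferred from the model's parameters, and the tiny linear model and synthetic data at the bottom are hypothetical stand-ins used only to exercise the loop.

```python
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    # Assumption: the model already lives on the device we want to train on.
    device = next(model.parameters()).device
    best_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    val_acc_history = []

    for epoch in range(num_epochs):
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()
            else:
                model.eval()
            running_loss, running_corrects, seen = 0.0, 0, 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                # Track gradients only during the training phase.
                with torch.set_grad_enabled(phase == "train"):
                    if is_inception and phase == "train":
                        # Inception v3 returns (final, auxiliary) outputs in train mode.
                        outputs, aux_outputs = model(inputs)
                        loss = criterion(outputs, labels) + 0.4 * criterion(aux_outputs, labels)
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                    preds = outputs.argmax(dim=1)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()
                seen += inputs.size(0)

            epoch_loss = running_loss / seen
            epoch_acc = running_corrects / seen
            print(f"epoch {epoch + 1}/{num_epochs} {phase} loss: {epoch_loss:.4f} acc: {epoch_acc:.4f}")

            # Remember the weights with the best validation accuracy so far.
            if phase == "val":
                val_acc_history.append(epoch_acc)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_wts)
    return model, val_acc_history


# Hypothetical toy data and model, only to show how the function is called.
torch.manual_seed(0)
x = torch.randn(16, 4)
y = (x[:, 0] > 0).long()
dataset = TensorDataset(x, y)
loaders = {phase: DataLoader(dataset, batch_size=8) for phase in ["train", "val"]}
toy = nn.Linear(4, 2)
trained, history = train_model(
    toy, loaders, nn.CrossEntropyLoss(), optim.SGD(toy.parameters(), lr=0.1), num_epochs=2
)
```

Returning the validation accuracy history alongside the model is a convenience for plotting; the description above only requires that the best model be returned.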
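# When feature extracting, freeze every existing parameter so that only
# layers added afterwards (which default to requires_grad=True) are trained.
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False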
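To make the effect of this helper concrete, here is a small sketch that freezes a model, swaps in a new final layer, and lists the parameters that remain trainable. The two-layer nn.Sequential network is a hypothetical stand-in for a pretrained model, and the SGD hyperparameters are illustrative assumptions.

```python
import torch.nn as nn
import torch.optim as optim


def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


# Hypothetical toy network standing in for a pretrained model.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 3))

# Freeze everything, then replace the last layer. Newly constructed
# modules have requires_grad=True by default, so only the new layer trains.
set_parameter_requires_grad(model, feature_extracting=True)
model[2] = nn.Linear(4, 2)

params_to_update = [name for name, p in model.named_parameters() if p.requires_grad]
print(params_to_update)  # ['2.weight', '2.bias']

# Hand only the trainable parameters to the optimizer.
optimizer = optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.001, momentum=0.9
)
```

The order matters: calling the helper after replacing the layer would freeze the new layer as well.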
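Initialization
Print the structure of the pretrained network, then redefine its last layer, for example:
# ResNet (512 input features for resnet18/resnet34)
model.fc = nn.Linear(512, num_classes)
# AlexNet, VGG
model.classifier[6] = nn.Linear(4096, num_classes)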
import torch.nn as nn
from torchvision import models

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    # variables is model specific.
    model_ft = None
    input_size = 0

    # Only the VGG branch is shown here; any other model name falls through to the error below.
    if model_name == "vgg":
        # VGG11_bn
        # Note: torchvision 0.13+ prefers the `weights=` argument over `pretrained=`.
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
# Print the model we just instantiated
print(model_ft)