3.1 源码解析
train.py 讲解
此代码为C3D模型的训练部分,分为训练前的准备,和训练部分两大部分。
1.训练前的准备
1.1 参数的设置
nEpochs = 101 # Number of epochs for training
resume_epoch = 0 # Default is 0, change if want to resume 即参数改变重头训练
useTest = True # See evolution of the test set when training
nTestInterval = 20 # Run on test set every nTestInterval epochs
snapshot = 25 # Store a model every snapshot epochs
lr = 1e-5 # Learning rate
save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__))) # save_dir_root = '...\\C3D'
exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1] # exp_name = '...\\C3D'
此部分为一些参数的设置
os.path.dirname(–file–) 获取当前运行脚本的路径
1.2 模型和数据集的载入
model = C3D_model.C3D(num_classes=num_classes, pretrained=False)
train_params = [{
'params': C3D_model.get_1x_lr_params(model), 'lr': lr},
{
'params': C3D_model.get_10x_lr_params(model), 'lr': lr * 10}]
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(train_params, lr=lr, momentum=0.9, weight_decay=5e-4) # 优化方法,梯度下降
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10,
gamma=0.1)
# 加载数据集
train_dataloader = DataLoader(VideoDataset(dataset=