import copy
import os
import time

import torch


def train_model(model, criterion, optimizer, scheduler, num_epochs=36):
    """
    Train the model, keeping the weights that achieve the best validation accuracy.

    Relies on module-level globals defined elsewhere in this script:
    dataloaders, dataset_sizes, device, device_ids, use_gpu, start_epoch,
    writer, and args.

    Params:
        model: torch.nn.Module
            Neural network model
        criterion: torch.nn.Module
            Loss function
        optimizer: torch.optim.Optimizer
            Optimization strategy
        scheduler: torch.optim.lr_scheduler._LRScheduler
            Learning rate scheduler
        num_epochs: int
            Number of epochs for training
    Returns:
        model: torch.nn.Module
            Trained model, loaded with the weights that achieved the best
            validation accuracy
    """
    since = time.time()
    # initialize the best accuracy and the corresponding model weights
    best_acc = 0.0
    if use_gpu and len(device_ids) > 1:  # model is wrapped in nn.DataParallel
        best_model_wts = copy.deepcopy(model.module.state_dict())
    else:
        best_model_wts = copy.deepcopy(model.state_dict())
    for epoch in range(start_epoch, num_epochs + 1):
        epoch_since = time.time()
        print('Epoch {}/{}'.format(epoch, num_epochs))
        # each epoch has a training and a validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # set model to training mode
            else:
                model.eval()  # set model to evaluation mode
            running_loss = 0.0
            running_corrects = 0
            tic_batch = time.time()
            # iterate over the data
            for i, (inputs, labels) in enumerate(dataloaders[phase]):
                inputs, labels = inputs.to(device), labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward; track gradient history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
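                    # torch.max over dim 1 returns (values, indices); the
                    # indices serve as the predicted class labels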
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
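                # the criterion returns the batch-mean loss, so multiply by
                # the batch size to accumulate a dataset-level sum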
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)
                if phase == 'train':
                    # record the training loss for each iteration
                    curr_iter = (epoch - 1) * len(dataloaders['train']) + i + 1
                    writer.add_scalar('Training_Loss', loss.item(), curr_iter)
                    if i % args.print_freq == 0:
                        print(
                            'Epoch {}/{}-batch:{}/{} lr:{:.4f} {} Loss: {:.6f} Acc: {:.4f} Speed: {:.4f} batch/s'.format(
                                epoch, num_epochs, i, len(dataloaders[phase]) - 1,
                                scheduler.get_last_lr()[0], phase,
                                loss.item(), torch.sum(preds == labels).item() / labels.size(0),
                                args.print_freq / (time.time() - tic_batch)
                            )
                        )
                        tic_batch = time.time()
            epoch_loss = running_loss / dataset_sizes[phase]  # dataset_sizes is a module-level dict
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # record training and validation accuracy for each epoch
            writer.add_scalars('accuracy', {phase: epoch_acc}, epoch)
            # deep copy the state_dict with the highest validation accuracy
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                if use_gpu and len(device_ids) > 1:  # model is wrapped in nn.DataParallel
                    best_model_wts = copy.deepcopy(model.module.state_dict())
                else:
                    best_model_wts = copy.deepcopy(model.state_dict())
        # passing the epoch keeps the schedule aligned when resuming from
        # start_epoch > 1 (note: the epoch argument is deprecated in recent PyTorch)
        scheduler.step(epoch)
        if epoch % args.save_epoch_freq == 0:
            if not os.path.exists(args.save_path):
                os.makedirs(args.save_path)
            checkpoint = {
                'model_state_dict': best_model_wts,  # best-so-far weights (highest val accuracy)
                'optim_state_dict': optimizer.state_dict(),
                'epoch': epoch
            }
            torch.save(checkpoint, os.path.join(args.save_path, "epoch_" + str(epoch) + ".pth"))
        # record the weight distributions of conv and fc layers
        for name, param in model.named_parameters():
            # parameter names look like 'conv1.weight'; split into layer and attribute
            layer, attr = name.rsplit('.', 1)
            if 'conv' in layer or 'fc' in layer:
                writer.add_histogram('{}_{}'.format(layer, attr), param, epoch)
        epoch_elapsed = time.time() - epoch_since
        print('Time taken for the epoch: {:.0f}m {:.0f}s'.format(epoch_elapsed // 60, epoch_elapsed % 60))
    time_elapsed = time.time() - since
    print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    print('-' * 30)
    # load the best weights before returning, as promised by the docstring
    if use_gpu and len(device_ids) > 1:
        model.module.load_state_dict(best_model_wts)
    else:
        model.load_state_dict(best_model_wts)
    return model
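
# Usage sketch (illustrative, not this script's actual configuration: the
# resnet18 model and hyperparameters below are assumptions, and dataloaders,
# writer, args, etc. must already be set up at module level):
#
#     model = torchvision.models.resnet18().to(device)
#     criterion = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
#     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
#     model = train_model(model, criterion, optimizer, scheduler, num_epochs=36)
#
# To resume training from a saved checkpoint, load it before calling
# train_model and set start_epoch from the stored epoch (the checkpoint holds
# the best weights seen so far plus the optimizer state; the file name is an
# example):
#
#     checkpoint = torch.load(os.path.join(args.save_path, 'epoch_10.pth'))
#     model.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optim_state_dict'])
#     start_epoch = checkpoint['epoch'] + 1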