主函数
# Build the training and validation data loaders. Both use the same batch
# size and worker count, and both are shuffled.
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size, shuffle=True,
num_workers=args.workers,
pin_memory=True)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.batch_size, shuffle=True,
num_workers=args.workers,
pin_memory=True)
# Evaluation-only mode: run a single validation pass and stop.
if args.evaluate:
validate(val_loader, model, criterion)
return
# Epoch loop, from args.start_epoch up to (not including) args.epochs.
for epoch in range(args.start_epoch, args.epochs):
# Set this epoch's learning rate (step decay driven by args.lr_steps).
adjust_learning_rate(optimizer, epoch)
# Train the network for one epoch.
train(train_loader, model, criterion, optimizer, epoch)
# Validation runs only every args.save_freq epochs; on all other epochs
# prec1 stays at 0.0.
prec1 = 0.0
if (epoch + 1) % args.save_freq == 0:
prec1 = validate(val_loader, model, criterion)
# Compare against the best validation top-1 precision seen so far.
# remember best prec@1 and save checkpoint
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
# Save this epoch's checkpoint every args.save_freq epochs; save_checkpoint
# additionally copies it to the best-model file when is_best is set.
# NOTE(review): "%3d" pads args.split with spaces, so the file name can
# contain blanks -- confirm this is intended ("%d" or "%03d" may be meant).
if (epoch + 1) % args.save_freq == 0:
checkpoint_name = "%s_%s%3d_%03d_%s" % (args.modality,"split",args.split,epoch + 1, "checkpoint.pth.tar")
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
}, is_best, checkpoint_name, args.resume)
1、训练函数
首先定义三个统计量:batch_time 表示每次参数更新(即连续 iter_size 个 mini-batch)所花费的平均时间,losses 表示到目前为止的平均训练损失,top1 表示到目前为止在训练集上的平均 top1 准确率。loss_mini_batch 和 acc_mini_batch 则分别用来累计一次参数更新内 iter_size 个 mini-batch 的损失与准确率。之所以这样做,是因为每次从 dataloader 中取出的是 batch_size 个数据,而梯度要累积 iter_size 次之后才回传更新一次参数,也就是说每次更新实际使用了 batch_size*iter_size 个数据。
def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch using gradient accumulation.

    Gradients of ``args.iter_size`` consecutive mini-batches are accumulated
    before each optimizer step, so one parameter update is based on
    ``batch_size * args.iter_size`` samples.

    Args:
        train_loader: Iterable yielding ``(input, target)`` mini-batches.
        model: Network to train; it is switched to training mode here.
        criterion: Loss function, called as ``criterion(output, target)``.
        optimizer: Optimizer stepped once every ``args.iter_size`` batches.
        epoch: Current epoch index; only used in the progress message.
    """
    # Running statistics; each one is updated once per optimizer step.
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()

    end = time.time()
    optimizer.zero_grad()
    # Accumulators over the mini-batches that make up one optimizer step.
    loss_mini_batch = 0.0
    acc_mini_batch = 0.0

    # NOTE: if len(train_loader) is not a multiple of args.iter_size, the
    # gradients of the trailing mini-batches are never applied; they are
    # cleared by the zero_grad() call at the start of the next invocation.
    for i, (inputs, target) in enumerate(train_loader):
        # `async` is a reserved word since Python 3.7; `non_blocking` is the
        # keyword PyTorch provides for the same asynchronous copy.
        inputs = inputs.float().cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        output = model(inputs)

        # Top-1 precision of this mini-batch (top-3 is computed but unused).
        prec1, _ = accuracy(output.data, target, topk=(1, 3))
        acc_mini_batch += prec1.item()

        # Divide by iter_size so the accumulated gradient corresponds to the
        # mean loss over the iter_size mini-batches.
        loss = criterion(output, target) / args.iter_size
        loss_mini_batch += loss.item()
        loss.backward()

        if (i + 1) % args.iter_size == 0:
            # Apply the accumulated gradients, then start a new accumulation.
            optimizer.step()
            optimizer.zero_grad()

            losses.update(loss_mini_batch, inputs.size(0))
            top1.update(acc_mini_batch / args.iter_size, inputs.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
            loss_mini_batch = 0
            acc_mini_batch = 0

            if (i + 1) % args.print_freq == 0:
                # i + 1 runs from 1 to len(train_loader), so the total shown
                # is len(train_loader) (the original printed one too many).
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, i + 1, len(train_loader),
                          batch_time=batch_time, loss=losses, top1=top1))
2、验证函数
与训练函数类似,batch_time、losses、top1、top3 分别表示已测试数据上的平均测试时间、平均损失以及平均 top1、top3 准确率。
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and return the mean top-1 precision.

    Args:
        val_loader: Iterable yielding ``(input, target)`` mini-batches.
        model: Network to evaluate; it is switched to evaluation mode here.
        criterion: Loss function, called as ``criterion(output, target)``.

    Returns:
        The running average of the top-1 precision over all batches,
        weighted by batch size.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top3 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # No backward pass happens here, so skip building the autograd graph.
    with torch.no_grad():
        for i, (inputs, target) in enumerate(val_loader):
            # `async` is a reserved word since Python 3.7; `non_blocking`
            # is the keyword PyTorch provides for the same behaviour.
            inputs = inputs.float().cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(inputs)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec3 = accuracy(output.data, target, topk=(1, 3))
            n_samples = inputs.size(0)
            losses.update(loss.item(), n_samples)
            top1.update(prec1.item(), n_samples)
            top3.update(prec3.item(), n_samples)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@3 {top3.val:.3f} ({top3.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=top1, top3=top3))

    print(' * Prec@1 {top1.avg:.3f} Prec@3 {top3.avg:.3f}'
          .format(top1=top1, top3=top3))

    return top1.avg
3、计算top1,top3准确率
把这一块单独拿出来说,是因为 target 的大小是 1*batch_size,一开始我一直在找代码是在哪里对 target 做变换、进而计算正确率的,因此在这个函数上花了点时间研究。这里 k=3,即预测概率最大的前 3 个结果中只要有一个与正确结果相同,就算它预测成功,这样得到的就是 top3 准确率。我之前一直用的是 top1 准确率,也是通过这个项目才认识了 topk。
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k.

    The ``max(topk)`` highest-scoring class indices are taken for every
    sample, transposed to ``(maxk, batch_size)``, and compared against the
    target row repeated ``maxk`` times. A sample counts as correct for a
    given ``k`` when its target appears among its first ``k`` predictions.

    Args:
        output: Score tensor indexed as ``(batch_size, num_classes)``.
        target: Tensor holding one class index per sample.
        topk: Iterable of the ``k`` values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``, each holding the
        percentage (0-100) of samples counted as correct for that ``k``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk largest scores per sample: (batch_size, maxk).
    _, pred = output.topk(maxk, 1, True, True)
    # Transpose to (maxk, batch_size) so that row j is the (j+1)-th guess.
    pred = pred.t()
    # Broadcast the target row against every guess row.
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape() rather than view(): the transposed tensor is not
        # contiguous, and view(-1) raises on such input for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        # Convert the hit count to a percentage of the batch.
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
4、其他函数
# Save the model checkpoint
def save_checkpoint(state, is_best, filename, resume_path):
    """Write ``state`` to disk and keep a copy of the best model so far.

    Args:
        state: Object to serialise with ``torch.save``.
        is_best: When true, the saved file is also copied to
            ``<args.arch>_model_best.pth.tar`` in the same directory.
        filename: File name of the checkpoint inside ``resume_path``.
        resume_path: Directory the checkpoint is written to.
    """
    # torch.save does not create missing directories; make sure the target
    # exists so a finished epoch is not lost to a missing folder.
    if resume_path and not os.path.isdir(resume_path):
        os.makedirs(resume_path)

    cur_path = os.path.join(resume_path, filename)
    torch.save(state, cur_path)

    if is_best:
        best_path = os.path.join(resume_path, args.arch + '_model_best.pth.tar')
        shutil.copyfile(cur_path, best_path)
# Metric bookkeeping: reset() clears the state, update() maintains the
# running weighted average.
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes set by ``reset``/``update``:
        val: Most recent value passed to ``update``.
        sum: Sum of ``val * n`` over all updates.
        count: Sum of the weights ``n`` over all updates.
        avg: ``sum / count``, or its previous value while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(val, 0) on a fresh meter) would
        # divide by zero; keep the previous average in that case.
        if self.count:
            self.avg = self.sum / self.count
# Adjust the learning rate
def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate of every parameter group for ``epoch``.

    The base rate ``args.lr`` is multiplied by 0.1 once for each entry of
    ``args.lr_steps`` that ``epoch`` has reached (``epoch >= step``): after
    the first step the rate is 10x smaller, after the second 100x smaller,
    and so on.

    Args:
        optimizer: Optimizer whose ``param_groups`` get the new ``'lr'``.
        epoch: Current epoch index, compared against ``args.lr_steps``.
    """
    # Number of decay milestones already reached.
    steps_passed = int(np.sum(epoch >= np.asarray(args.lr_steps)))
    # Store a plain Python float: a numpy scalar here would end up in the
    # optimizer's param groups and thus in every saved checkpoint.
    lr = float(args.lr * 0.1 ** steps_passed)
    print("Current learning rate is %4.6f:" % lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr