1、AverageMeter 类
计算并存储平均值和当前值。
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
在初始化的时候就调用的重置方法reset。
当调用该类对象的update方法的时候就会进行变量更新。
top1 = AverageMeter('Acc@1', ':6.2f')
top1.update(acc1[0], image.size(0))
上述语句update的n为一个batchsize的大小,保存了top1准确率的值和均值。
2、accuracy()
计算top1 & top5的准确率
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True) # maxk-top-1;1-按行求最大/最小值;true-求最大值;true-排序;返回张量和索引
pred = pred.t() # 转置矩阵
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
在下文中已经详细介绍了
https://blog.csdn.net/weixin_45919003/article/details/130060658
3、evaluate()
测试集的准确率
def evaluate(model, criterion, data_loader, neval_batches):
model.eval()
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
cnt = 0
with torch.no_grad():
for image, target in data_loader:
output = model(image)
loss = criterion(output, target)
cnt += 1
acc1, acc5 = accuracy(output, target, topk=(1, 5))
print('.', end = '')
top1.update(acc1[0], image.size(0))
top5.update(acc5[0], image.size(0))
if cnt >= neval_batches:
return top1, top5
return top1, top5
4、load_model()
def load_model(model_file):
model = MobileNetV2()
state_dict = torch.load(model_file)
model.load_state_dict(state_dict)
model.to('cpu')
return model
使用案例:
float_model = load_model(saved_model_dir + float_model_file).to('cpu')
模型保存与重载模块的方法
- 保存模型参数和结构
torh.save(model, 'model_params.pth')
model = torch.load('model_params.pth')
- 仅保存参数
torh.save(model.state_dict(), 'model_params.pth')
model = torch.load_state_dict(torch.load('model_params.pth'))
上诉load_model()就是使用的方法2
5、print_size_of_model()
def print_size_of_model(model):
torch.save(model.state_dict(), "temp.p")
print('Size (MB):', os.path.getsize("temp.p")/1e6)
os.remove('temp.p')
其中 getsize() 函数用于获取一个文件的大小,以字节的形式返回
- 1MB=1024KB=1048576字节(Byte)