pytorch量化的辅助模块

1、AverageMeter 类

计算并存储平均值和当前值。

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

初始化的时候就调用的重置方法reset。
当调用该类对象的update方法的时候就会进行变量更新。

top1 = AverageMeter('Acc@1', ':6.2f')
top1.update(acc1[0], image.size(0))

上述语句update的n为一个batchsize的大小,保存了top1准确率的值和均值。

2、accuracy()

计算top1 & top5的准确率

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)                            
        batch_size = target.size(0)                 

        _, pred = output.topk(maxk, 1, True, True)  # maxk-top-1;1-按行求最大/最小值;true-求最大值;true-排序;返回张量和索引
        pred = pred.t()                             # 转置矩阵
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

在下文中已经详细介绍了
https://blog.csdn.net/weixin_45919003/article/details/130060658

3、evaluate()

测试集的准确率

def evaluate(model, criterion, data_loader, neval_batches):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print('.', end = '')
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if cnt >= neval_batches:
                 return top1, top5

    return top1, top5

4、load_model()

def load_model(model_file):
    model = MobileNetV2()
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to('cpu')
    return model

使用案例:

float_model = load_model(saved_model_dir + float_model_file).to('cpu')

模型保存与重载模块的方法

  • 保存模型参数和结构
torh.save(model, 'model_params.pth')
model = torch.load('model_params.pth')
  • 仅保存参数
torh.save(model.state_dict(), 'model_params.pth')
model = torch.load_state_dict(torch.load('model_params.pth'))

上诉load_model()就是使用的方法2

5、print_size_of_model()

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

其中 getsize() 函数用于获取一个文件的大小,以字节的形式返回

  • 1MB=1024KB=1048576字节(Byte)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值