1. 前言
通过构建类来批量保存每次小批量训练时训练集的精度
2. 代码部分
类的实现
class Accumulator:
def __init__(self, n):
'''
n 表示存储的变量数
例如:样本总数,正确预测样本的个数等
'''
self.data = [0.0] * n # n个位置存放n个变量
def add(self, *args):
self.data = [a + float(b) for a, b in zip(self.data, args)]
def reset(self):
self.data = [0] * len(self.data)
def __getitem__(self, index):
# 魔法方法,使得实例化类之后可以被索引
return self.data[index]
应用举例:
def train(model, train_iter, loss, optimizer):
'''
model: 表示训练的网络
train_iter: 训练数据的迭代对象
loss: 损失函数
optimizer: 优化算法
'''
accumulator = Accumulator(2)
for X, y in train_iter:
optimizer.zero_grad()
l = loss(model(X), y)
l.backward()
true_num = sum((model(X).astype(y.dtype) == y.dtype).astype(y.dtype))
accumulator.add(true_num, y.numel())
return accumulator[0] / accumulator[1]