Pytorch构建类保存小批量训练精度

1. 前言

通过构建类来批量保存每次小批量训练时训练集的精度

2. 代码部分

类的实现

class Accumulator:
	def __init__(self, n):
		'''
		n 表示存储的变量数
		例如:样本总数,正确预测样本的个数等
		'''
		self.data = [0.0] * n # n个位置存放n个变量
	def add(self, *args):
		self.data = [a + float(b) for a, b in zip(self.data, args)]
	def reset(self):
		self.data = [0] * len(self.data)
	def __getitem__(self, index):
		# 魔法方法,使得实例化类之后可以被索引
		return self.data[index]

应用举例:

def train(model, train_iter, loss, optimizer):
	'''
	model: 表示训练的网络
	train_iter: 训练数据的迭代对象
	loss: 损失函数
	optimizer: 优化算法
	'''
	accumulator = Accumulator(2)
	for X, y in train_iter:
		optimizer.zero_grad()
		l = loss(model(X), y)
		l.backward()
		true_num = sum((model(X).astype(y.dtype) == y.dtype).astype(y.dtype))
		accumulator.add(true_num, y.numel())
	return accumulator[0] / accumulator[1]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值