batch & print pro_keras train_on_batch详解

010d3358eae972b168d60956414561e3.png

学习过程中,遇到一些不明白的地方,随笔记录一下。

利用 train_on_batch 精细管理训练过程

大部分使用 keras 的同学使用 fit() 或者 fit_generator() 进行模型训练, 这两个 api 对于刚接触深度学习的同学非常友好和方便,但是由于其是非常深度的封装,对于希望自定义训练过程的同学就显得不是那么方便(从 torch 转 keras 的同学可能更喜欢自定义训练过程),而且,对于 GAN 这种需要分步进行训练的模型,也无法直接使用 fit 或者 fit_generator 直接训练的。因此,keras 提供了 train_on_batch 这个 api,对一个 mini-batch 的数据进行梯度更新。
总结优点如下:

  • 更精细自定义训练过程,更精准的收集 loss 和 metrics
  • 分步训练模型-GAN的实现
  • 多GPU训练保存模型更加方便
  • 更多样的数据加载方式,结合 torch dataloader 的使用

1. train_on_batch 的输入输出

1.1 输入

y_pred = Model.train_on_batch(
    x,
    y=None,
    sample_weight=None,
    class_weight=None,
    reset_metrics=True,
    return_dict=False,
)
  • x:模型输入,单输入就是一个 numpy 数组, 多输入就是 numpy 数组的列表
  • y:标签,单输出模型就是一个 numpy 数组, 多输出模型就是 numpy 数组列表
  • sample_weight:mini-batch 中每个样本对应的权重,形状为 (batch_size)
  • class_weight:类别权重,作用于损失函数,为各个类别的损失添加权重,主要用于类别不平衡的情况, 形状为 (num_classes)
  • reset_metrics:默认True,返回的metrics只针对这个mini-batch, 如果False,metrics 会跨批次累积
  • return_dict:默认 False, y_pred 为一个列表,如果 True 则 y_pred 是一个字典

1.2 输出

  • 单输出模型,且只有loss,没有metrics, 此时 y_pred 为一个标量,代表这个 mini-batch 的 loss, 例如下面的例子
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam, loss=['binary_crossentropy'])
y_pred = model.train_on_batch(x=image,y=label) # y_pred is a scalar
# y_pred 为标量
  • 单输出模型,既有loss,也有metrics, 此时 y_pred 为一个列表,代表这个 mini-batch 的 loss 和 metrics, 列表长度为 1+len(metrics), 例如下面的例子
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam, loss=['binary_crossentropy'], metrics=['accuracy'])
y_pred = model.train_on_batch(x=image,y=label) # len(y_pred) == 2
# y_pred 为长度为2的列表, y_pred[0]为loss, y_pred[1]为accuracy
  • 多输出模型,既有loss,也有metrics, 此时 y_pred 为一个列表,列表长度为 1+len(loss)+len(metrics), 例如下面的例子
model = keras.models.Model(inputs=inputs, outputs=[output1, output2])
model.compile(Adam, loss=['binary_crossentropy', 'binary_crossentropy'], 
			  metrics=['accuracy', 'accuracy'])
y_pred = model.train_on_batch(x=image,y=label) # len(y_pred) == 5
# y_pred[0]为总loss(按照loss_weights加权),
# y_pred[1]为第一个输出的loss, y_pred[2]为第二个输出的loss
# y_pred[3]为第一个accuracy,y_pred[4]为第二个accuracy

2. train_on_batch 多GPU训练模型

2.1 多GPU模型初始化,加载权重,模型编译,模型保存

import tensorflow as tf
import keras
import os

# 初始化GPU的使用个数
gpu = "0,1"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu
gpu_num = len(gpu.split(','))

# model初始化
if gpu_num >= 2: # gpu_num表示GPU的数量
	with tf.device('/cpu:0'): # 使用多GPU时,先在CPU上初始化模型
		model = YourModel(input_size, num_classes)
		model.load_weights('*.h5') # 如果有权重需要加载,在这里实现
	para_model = keras.utils.multi_gpu_model(model, gpus=gpu_num) # 默认在GPU上初始化多GPU模型
	para_model.compile(optimizer, loss=[...], metrics=[...]) # 只编译多GPU模型
else: # 单GPU
	model = YourModel(input_size, num_classes) # 直接在GPU上初始化单GPU模型
	model.load_weights('*.h5') # 加载权重
	model.compile(optimizer, loss=[...], metrics=[...]) # 编译模型
	
# 训练和验证
def train():
	if gpu_num>=2:
		para_model.train_on_batch(...)
	else:
		model.train_on_batch(...)
		
def evaluate():
	if gpu_num>=2:
		para_model.test_on_batch(...)
	else:
		model.test_on_batch(...)
		
# 保存模型,不管是单GPU还是多GPU,只需要对model做保存操作
# 不要使用 para_model.save() 或者 para_model.save_weights(),否则加载时会出问题
model.save('*.h5')
model.save_weights('*.h5')

3. keras和torch的结合

torch 的 dataloader 是最好用的数据加载方式,使用 train_on_batch 一部分的原因是能够用 torch dataloader 载入数据,然后用 train_on_batch 对模型进行训练,通过合理的控制 cpu worker 的使用个数和 batch_size 的大小,使模型的训练效率最大化

3.1 dataloader+train_on_batch 训练keras模型pipeline

# 定义 torch dataset
class Dataset(torch.utils.data.Dataset):
	def __init__(self, root_list, transforms=None):
		self.root_list = root_list
		self.transforms = transforms
		
	def __getitem__(self, idx):
		# 假设是图像分类任务
		image = ... # 读取单张图像
		label = ... # 读取标签
		if self.transforms is not None:
			image = self.transforms(image)
		return image, label # shape: (H,W,3), salar
		
	def __len__(self):
		return len(self.root_list)
		
# 自定义 collate_fn 使 dataloader 返回 numpy array
def collate_fn(batch):
	# 这里的 batch 是 tuple 列表,[(image, label),(image, label),...]
	image, label = zip(*batch)
	image = np.asarray(image) # (batch_size, H, W, 3)
	label = np.asarray(label) # (batch_size)
	return image, label # 如果 datast 返回的图像是 ndarray,这样loader返回的也是 ndarray
	
# 定义dataset
train_dataset = Dataset(train_list)
valid_dataset = Dataset(valid_list)
test_dataset = Dataset(test_list)

# 定义 dataloader, 如果不使用自定义 collate_fn,
# 从 loader 取出的默认是 torch Tensor,需要做一个 .numpy()的转换
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)

# 定义 train,evaluate,test
def train():
	for i,(inputs, label) in enumerate(train_loader):
		# 如果 inputs 和 label 是 torch Tensor
		# 请用 inputs = inputs.numpy() 和 label = label.numpy() 转成 ndarray
		y_pred = model.train_on_batch(inputs, label)
		
def evaluate():
	for i,(inputs, label) in enumerate(valid_loader):
		# 如果 inputs 和 label 是 Tensor,同上
		y_pred = model.test_on_batch(inputs, label)
		
def test():
	for i,(inputs, label) in enumerate(test_loader):
		# 如果 inputs 和 label 是 Tensor,同上
		y_pred = model.test_on_batch(inputs, label)
		
def run():
	for epoch in num_epoch:
		train()
		evaluate()
	test()
	
if __name__ == "__main__":
	run()

总结

还有一些使用 train_on_batch 的地方比如 GAN 的训练,这里就不介绍了,具体可以上 github 上搜索例子。

参考

Keras documentation: Model training APIs​keras.io
5f419a29bb49e5677ab443a36863b6b2.png

注:以上资源均来自互联网

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值