一文读懂—Pytiorch混合精度训练

> 复现代码时遇到了自动混合精度。查阅资料得知,Pytorch从1.60开始支持自动混合精度训练。其中自动、混合精度是两个关键词,那么代表什么意思呢?一起来看看吧!

🎏目录

    🎈1
    🎈2
      🎄2.1
      🎄2.2
      🎄2.3
      🎄2.4

✨1 混合精度训练简介

目前,Pytorch一共支持10种数据类型

  • torch.FloatTensor # 另一种表述:FP32**
  • torch.DoubleTensor # 64-bit floating point
  • torch.HalfTensor # 另一种表述:FP16
  • torch.ByteTensor
  • torch.CharTensor
  • torch.ShortTensor
  • torch.IntTensor
  • torch.LongTensor

默认使用的是32位浮点型精度的Tensor,即torch.FloatTensor 。因此,默认情况下我们训练的是一个FP32的模型。但不是所有数据都需要FP32那么大的内存。
此时,采用自动混合精度(Automatic Mixed Precision, AMP)训练,一部分算子数值精度为FP16,其余算子的数值精度是FP32,而哪些算子用FP16,哪些用FP32,由amp自动安排。这样在不改变模型、不降低模型训练精度的前提下,可以缩短训练时间,降低存储需求,因而能支持更多的 batch size、更大模型和尺寸更大的输入进行训练。

✨2 自动混合精度训练的使用

结合一段示例代码来看:

# amp依赖Tensor core架构,所以模型必须在cuda设备下使用
model = Model()
model.to("cuda")  # 必须!!!

optimizer = optim.SGD(model.parameters(), ...)

# (新增)创建GradScaler对象
scaler = GradScaler(enabled=True)  # 虽然默认为True,体验一下过程

for epoch in epochs:
    for img, target in data:
        optimizer.zero_grad()
        # (新增)启动autocast上下文管理器
        with autocast(enabled=True):
            # (不变)上下文管理器下,model前向传播,以及loss计算自动切换数值精度
            output = model(img)
            loss = loss_fn(output, target)
        # (修改)反向传播
        scaler.scale(loss).backward()
        # (修改)梯度计算
        scaler.step(optimizer)
        # (新增)scaler更新
        scaler.update()

使用自动化精度时,只有在模型以及损失计算,反向传播,梯度更新时作出一定的改变,具体有:

  1. scaler = GradScaler():创建对象GradScaler,并赋予变量scaler
  2. with autocast()::启动autocast上下文管理器,内含需要做精度放缩的计算(必须包含模型计算以及损失计算)
  3. scaler.scale(loss).backward():利用scaler做反向传播
  4. scaler.step(optimizer):梯度更新
  5. scaler.update():scaler更新

🎃 2.2 GradScaler

构造:

torch.cuda.amp.GradScaler(
	init_scale=65536.0,
	growth_factor=2.0,
	backoff_factor=0.5,
	growth_interval=2000,
	enabled=True,
)

这里形式很固定,只有一个参数enabled根据自己的需求进行改变

参数:

  1. enabled:是否做scale。如果为False,则返回原数据。如果为True,则进行一次精度转换。

原谅我,涉及到了原理,水平有限,真的看不懂,欢迎大家交流

🎉 2.2 autocast

autocast(
	enable=True  # 同上
)

上面的示例中,autocast是在训练脚本中使用的,除此之外还有两种方式:

  1. 作为装饰器,在forward函数中使用

==============================================================

class Model(nn.Module):

	def __init__(self):
		pass
	
	@torch.cuda.amp.autocast()  # autocast导入路径
	def forward():
		pass

==============================================================

  1. 在forward中使用上下文管理器

==============================================================

class Model(nn.Module):

	def __init__(self):
		pass
	
	def forward():
		with torch.cuda.amp.autocast():  # 上下文管理器
			pass

==============================================================

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

白三点

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值