T5为啥没法用fp16训练
最近基于T5做一些分类任务的训练,发现使用混合精度训练,无论哪个型号的mT5,必出现输出为nan的问题。
调试的时候发现,模型在Encoder阶段还算正常,在Decoder阶段,模型计算的过程中很容易出现inf,然后进而导致在后续的计算过程中出现nan。
目前发现在T5的self Attention的计算过程中,会出现-inf的情况,主要是Attention的分数在进行mask的过程中出现的,在transformers的实现中,Attention的mask来自get_extended_attention_mask
,这里返回的mask的数据类型是fp32,这就导致fp16格式的scores会出现-inf的问题:
T5 self-Attention的部分代码
if position_bias is None:
if not self.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
)
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
if self.pruned_heads:
mask = torch.ones(position_bias.shape[1])
mask[list(self.pruned_heads)] = 0
position_bias_masked = position_bias[:, mask.bool()]
else:
position_bias_masked = position_bias
# 这行是重点
scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores
) # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
但是即使修改了get_extended_attention_mask
返回的mask为fp16,或者mask的数值改成-32752.0(fp16可以表示的最小数字的一半),在计算过程中依旧会出现-inf的问题,目测是在经过一些线性层的时候,出现的这个问题,毕竟fp16的可表示范围小了很多。
所以-inf的问题难以解决,然后在经过MT5LayerFF的时候,其中会经过一个激活函数,是transformer中自定义的一个NewGELUActivation
激活函数,其中会出现inf * 0的情况,导致了出现nan的问题。
其实看T5的代码,里面有想解决这个inf问题的一些代码,比如在MT5Block
中有下面这段代码:
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
这段代码前半部分是经过Decoder的第一层multi-head-Attention,然后输出的结果为hidden_states,正常来说如果hidden_states为fp16,那么会经过下面的剪切代码,防止出现inf的情况,然后再经过MT5LayerFF
层就不会那么容易出现nan的问题。但是可惜的是,第一层multi-head-Attention的实现里面MT5LayerSelfAttention
,最后输出是残差结构,而残差是加法操作,在混合精度计算里面,加法操作用fp32进行计算,这就导致Decoder的第一层multi-head-Attention的输出的结果为fp32,这样就没有经过下面的针对fp16的裁剪。目前看这个针对fp16的裁剪可能在推理的时候有点用。
同时根据网上搜集的资料:
https://github.com/huggingface/transformers/issues/23918
https://www.philschmid.de/fine-tune-flan-t5-deepspeed
https://github.com/huggingface/transformers/issues/4586
https://github.com/huggingface/transformers/issues/10830
https://github.com/huggingface/transformers/pull/10956
发现这个可能并不是T5的问题,而是使用bf16训练的模型,然后使用fp16进行训练的通病,因为bf16比fp16表示的整数范围更大,和fp32一致,这就导致bf16转到fp16的时候出现问题,导致模型的数据出现异常。目前针对这个问题的比较好的解决办法就是,如果显卡支持bf16,可以使用bf16进行计算,否则使用fp32进行计算