2021SC@SDUSC
1、class Transition
class Transition(hk.Module):
def __init__(self, config, global_config, name='transition_block'):
    """Set up the transition block by recording its configuration objects.

    Args:
      config: Configuration for this block, stored unchanged as
        `self.config`. `__call__` reads `num_intermediate_factor` from it.
      global_config: Second configuration object, stored unchanged as
        `self.global_config`.
      name: Module name passed through to the parent class constructor.
    """
    # Parent constructor runs first, before any attribute is attached.
    super().__init__(name=name)
    self.config, self.global_config = config, global_config
def __call__(self, act, mask, is_training=True):
_, _, nc = act.shape
num_intermediate = int(nc * self.config.num_intermediate_factor)
mask = jnp.expand_dims(mask, axis=-1)
act = hk.LayerNorm(
axis=[-1],
create_scale=True,
create_offset