也就是说,高频词使用较高维度的词嵌入(例如 1024 或 512 维),低频词使用较低维度的词嵌入(例如 256 或 64 维),然后再用 Linear 层把它们投影(project)到相同的维数:
class AdaptiveEmbedding(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
sample_softmax=False):
super(AdaptiveEmbedding, self).__init__()
self.n_token = n_token # 793470
self.d_embed = d_embed # 1024
self.cutoffs = cutoffs + [n_token] # [60000, 100000, 640000, 793470]
self.div_val = div_val # 4
self.d_proj = d_proj # 1024
self.emb_scale = d_proj ** 0.5 # 32
self.cutoff_ends = [0] + self.cutoffs # [0, 60000, 100000, 640000, 793470]
self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ParameterList()
if div_val &#