Adaptive input representations源码阅读笔记

一.参考链接

二.代码

import torch.nn as nn

class AdaptiveInput(nn.Module):
    """
    This implementation and the above description are heavily cited from the softmax counterpart from
    https://pytorch.org/docs/stable/_modules/torch/nn/modules/adaptive.html
    """
    def __init__(self, in_features, n_classes, cutoffs=None,div_value=4., head_bias=False):
        """
        :param in_features:embeding的维度论文中的参数d
        :param n_classes:词汇表长度
        :param cutoffs: 是一个list,控制词汇表的划分,例如:cutoffs = [10, 100, 1000]
        表示将词汇表划分成五个clusters:[0-10]、[11-100]、[101-1000]和[1001-最后]
        :param div_value:就是论文中的参数 k
        :param head_bias:
        """
        super(AdaptiveInput, self).__init__()#初始化父类
        if not cutoffs:
            cutoffs = [10000, 60000, 190000]
        cutoffs = list(cutoffs)
        #检查cutoffs是否符合条件
        if (cutoffs != sorted(cutoffs)) \
                or (min(cutoffs) <= 0) \
                or (max(cutoffs) >= (n_classes - 1)) \
                or (len(set(cutoffs)) != len(cutoffs)) \
                or any([int(c) != c for c in cutoffs]):
            raise ValueError("cutoffs should be a sequence of unique, positive "
                             "integers sorted in an increasing order, where "
                             "each value is between 1 and n_classes-1")

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]#在cutoffs后面插入词汇表的最大值
        self.div_value = div_value
        self.head_bias = head_bias
        #这里将词汇表的第一个子集V1c称为head,其他的vi称为cluster
        self.n_clusters = len(self.cutoffs) - 1 #cluster的个数
        self.head_size = self.cutoffs[0]  #V1子集的大小
        #定义V1的embeding矩阵E1与映射矩阵W1
        self.head = nn.Sequential(nn.Embedding(self.head_size, self.in_features),
                                  nn.Linear(self.in_features, self.in_features, bias=self.head_bias))
        #其他的Vi的embedding矩阵Ei与映射矩阵Wi,放入该列表中
        self.tail = nn.ModuleList()
        for i in range(self.n_clusters):
            hsz = int(self.in_features // (self.div_value ** (i + 1))) #Ei维度
            osz = self.cutoffs[i + 1] - self.cutoffs[i] #Vi中的词汇数
            #定义Ei与Wi
            projection = nn.Sequential(
                nn.Embedding(osz, hsz),
                nn.Linear(hsz, self.in_features, bias=False),
            )
            #添加到ModuleList中去
            self.tail.append(projection)


    def forward(self, input):
        """
        :param input: 一个句子list,中间元素是句子中词在词汇表中的编号
        """
        used_rows = 0
        input_size = list(input.size()) #[q_len]

        output = input.new_zeros(input_size + [self.in_features]).float() #[q_len,in_features]

        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):

            low_idx = cutoff_values[i] #Vi 第一个词的索引
            high_idx = cutoff_values[i + 1]  #Vi 中最后一个词的索引

            input_mask = (input >= low_idx) & (input < high_idx) #将句子中属于Vi的词标为1,其他标为0
            row_indices = input_mask.nonzero().squeeze() #取出input_mask中个行中为1的列索引

            if row_indices.numel() == 0:#句子中没有Vi中的词
                continue
            #去除句子中属于Vi的词并输入
            out = self.head(input[input_mask] - low_idx) if i == 0 else self.tail[i - 1](input[input_mask] - low_idx)
            output.index_copy_(0, row_indices, out) #按照源句子中的位置将词embedding向量放到输出矩阵的相应位置。
            used_rows += row_indices.numel()  #记录已经有多少个词一个转化为embedding向量

        if used_rows != input_size[0]:
            raise RuntimeError("Target values should be in [0, {}], "
                               "but values in range [{}, {}] "
                               "were found. ".format(self.n_classes - 1,
                                                     input.min().item(),
                                                     input.max().item()))
        return output


# Example
import torch
x = torch.arange(0,100).long()
inp = AdaptiveInput(128, 100, cutoffs=[4,8,16])
print(inp(x))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值