工作记忆模型

本文主要graph.py代码进行注释。

def get_non_pad_mask(seq, pad_idx, device):
    # seq: [B, L]
    assert seq.dim() == 2
    # [B, L]
    mask = seq.ne(pad_idx).type(torch.float)
    return mask.to(device)


def get_seq_length(seq, pad_idx, device):
    mask = get_non_pad_mask(seq, pad_idx, device)
    # mask: [B, T]
    lengths = mask.sum(dim=-1).long()
    return lengths

这是两个用于处理序列数据的辅助函数,通常用于深度学习模型中。

  1. get_non_pad_mask(seq, pad_idx, device):

    • 输入参数:
      • seq: 一个形状为 [B, L] 的二维张量,表示一个批次的序列数据,其中 B 是批次大小,L 是最大序列长度。
      • pad_idx: 用于表示填充标记(通常是0)的整数值,它用来标识序列中的填充部分。
      • device: 表示计算设备(CPU或GPU)。
    • 功能:生成一个与输入序列相同形状的二进制掩码,用于标识哪些元素不是填充。掩码中的非填充元素将被标记为1,而填充元素将被标记为0。
    • 返回值:一个与输入 seq 形状相同的张量,其中非填充元素为1,填充元素为0,并且该张量位于指定的 device 上。
  2. get_seq_length(seq, pad_idx, device):

    • 输入参数:
      • seq: 一个形状为 [B, L] 的二维张量,表示一个批次的序列数据,其中 B 是批次大小,L 是最大序列长度。
      • pad_idx: 用于表示填充标记(通常是0)的整数值,它用来标识序列中的填充部分。
      • device: 表示计算设备(CPU或GPU)。
    • 功能:计算每个序列在不包含填充部分的情况下的实际长度。它利用 get_non_pad_mask 函数生成的掩码,对每个序列进行求和,以确定每个序列的有效长度。
    • 返回值:一个形状为 [B] 的张量,其中每个元素表示对应序列的有效长度。该张量位于指定的 device 上。

这些函数通常用于在处理序列数据时,将填充的部分排除在计算和处理之外,以便模型能够正确处理不同长度的序列。

下面代码定义了一个名为 WorkingMemoryModel 的PyTorch模型类,该模型用于工作记忆和生成古诗。

class WorkingMemoryModel(nn.Module):
    def __init__(self, hps, device):
        super(WorkingMemoryModel, self).__init__()
        self.hps = hps
        self.device = device

        self.global_trace_size = hps.global_trace_size
        self.topic_trace_size = hps.topic_trace_size
        self.topic_slots = hps.topic_slots
        self.his_mem_slots = hps.his_mem_slots

        self.vocab_size = hps.vocab_size
        self.mem_size = hps.mem_size

        self.sens_num = hps.sens_num

        self.pad_idx = hps.pad_idx
        self.bos_tensor = torch.tensor(hps.bos_idx, dtype=torch.long, device=device)

        # ----------------------------
        # build componets
        self.layers = nn.ModuleDict()
        self.layers['word_embed'] = nn.Embedding(hps.vocab_size,
            hps.word_emb_size, padding_idx=hps.pad_idx)

        # NOTE: We set fixed 33 phonology categories: 0~32
        #   please refer to preprocess.py for more details
        self.layers['ph_embed'] = nn.Embedding(33, hps.ph_emb_size)

        self.layers['len_embed'] = nn.Embedding(hps.sen_len, hps.len_emb_size)


        self.layers['encoder'] = BidirEncoder(hps.word_emb_size, hps.hidden_size, drop_ratio=hps.drop_ratio)
        self.layers['decoder'] = Decoder(hps.hidden_size, hps.hidden_size, drop_ratio=hps.drop_ratio)

        # project the decoder hidden state to a vocanbulary-size output logit
        self.layers['out_proj'] = nn.Linear(hps.hidden_size, hps.vocab_size)

        # update the context vector
        self.layers['global_trace_updater'] = ContextLayer(hps.hidden_size, hps.global_trace_size)
        self.layers['topic_trace_updater'] = MLP(self.mem_size+self.topic_trace_size,
            layer_sizes=[self.topic_trace_size], activs=['tanh'], drop_ratio=hps.drop_ratio)


        # MLP for calculate initial decoder state
        self.layers['dec_init'] = MLP(hps.hidden_size*2, layer_sizes=[hps.hidden_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)
        self.layers['key_init'] = MLP(hps.hidden_size*2, layer_sizes=[hps.hidden_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)

        # history memory reading and writing layers
        # query: concatenation of hidden state, global_trace and topic_trace
        self.layers['memory_read'] = AttentionReader(
            d_q=hps.hidden_size+self.global_trace_size+self.topic_trace_size+self.topic_slots,
            d_v=hps.mem_size, drop_ratio=hps.attn_drop_ratio)

        self.layers['memory_write'] = AttentionWriter(hps.mem_size+self.global_trace_size, hps.mem_size)

        # NOTE: a layer to compress the encoder hidden states to a smaller size for larger number of slots
        self.layers['mem_compress'] = MLP(hps.hidden_size*2, layer_sizes=[hps.mem_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)

        # [inp, attns, ph_inp, len_inp, global_trace]
        self.layers['merge_x'] = MLP(
            hps.word_emb_size+hps.ph_emb_size+hps.len_emb_size+hps.global_trace_size+hps.mem_size,
            layer_sizes=[hps.hidden_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)


        # two annealing parameters
        self._tau = 1.0
        self._teach_ratio = 0.8


        # ---------------------------------------------------------
        # only used for for pre-training
        self.layers['dec_init_pre'] = MLP(hps.hidden_size*2,
            layer_sizes=[hps.hidden_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)

        self.layers['merge_x_pre'] = MLP(
            hps.word_emb_size+hps.ph_emb_size+hps.len_emb_size,
            layer_sizes=[hps.hidden_size],
            activs=['tanh'], drop_ratio=hps.drop_ratio)

以下是一些重要组件和属性的解释:

  1. WorkingMemoryModel 类是一个PyTorch模型,继承自nn.Module

  2. __init__ 构造函数中,初始化了模型的各种参数和组件:

    • hps 是一个参数配置对象,包含模型的各种超参数(例如,词汇大小、嵌入维度、隐藏层大小等)。
    • device 表示计算设备(CPU或GPU)。
    • 其他属性包括模型的输入和输出尺寸、填充标记、起始标记等。
  3. 模型中的各个组件(如嵌入层、编码器、解码器、注意力机制等)都在 self.layers 中被定义为PyTorch模块。

  4. 以下是一些主要组件的解释:

    • word_embed 是词嵌入层,用于将词汇索引映射为词嵌入向量。
    • ph_embed 是一个特殊的嵌入层,用于将音韵类别索引映射为音韵嵌入向量。
    • len_embed 是一个嵌入层,用于将句子长度映射为长度嵌入向量。
    • encoder 是一个双向编码器。
    • decoder 是解码器。
    • out_proj 是一个线性层,用于将解码器的隐藏状态映射为词汇表大小的输出概率。
    • global_trace_updatertopic_trace_updater 是用于更新工作记忆中的全局痕迹和主题痕迹的组件。
    • memory_readmemory_write 是用于读取和写入工作记忆的注意力读写器。
    • mem_compress 是一个用于将编码器隐藏状态压缩到较小尺寸的层。
    • merge_x 是一个用于合并各种输入特征的多层感知机。
  5. 模型中还包括一些用于预训练的组件,例如 dec_init_premerge_x_pre

  6. 模型还定义了两个参数 _tau_teach_ratio,这些参数可能用于模型的训练或控制训练过程的某些超参数。

这段代码描述了一个复杂的深度学习模型,用于生成古诗,并包含了许多不同的组件和层,这些组件和层在生成过程中协同工作。模型的训练和使用将涉及到多个步骤和组件的协同操作。

    def set_tau(self, tau):
        if 0.0 < tau <= 1.0:
            self.layers['memory_write'].set_tau(tau)

    def get_tau(self):
        return self.layers['memory_write'].get_tau()

    def set_teach_ratio(self, teach_ratio):
        if 0.0 < teach_ratio <= 1.0:
            self._teach_ratio = teach_ratio

    def get_teach_ratio(self):
        return self._teach_ratio


    def set_null_idxes(self, null_idxes):
        self.null_idxes = null_idxes.to(self.device).unsqueeze(0)

这部分代码定义了一些方法,用于设置和获取模型的超参数和状态信息:

  1. set_tau(self, tau) 方法用于设置模型中的 tau 参数,该参数可能用于控制模型中的某些行为。tau 应该是一个介于0.0和1.0之间的浮点数,用于设置 memory_write 层的 tau 参数。

  2. get_tau(self) 方法用于获取当前模型中 tau 参数的值。

  3. set_teach_ratio(self, teach_ratio) 方法用于设置模型中的 _teach_ratio 参数,该参数可能用于控制训练过程中的教师强制比率。teach_ratio 应该是一个介于0.0和1.0之间的浮点数。

  4. get_teach_ratio(self) 方法用于获取当前模型中 _teach_ratio 参数的值。

  5. set_null_idxes(self, null_idxes) 方法用于设置模型中的 null_idxes,这可能是一些索引值。这些索引值被设置为模型的属性 null_idxes 并将其移到模型的计算设备(device)上。null_idxes 通常是一个张量,包含模型用来表示空白或未定义值的特殊索引。

这些方法提供了一种设置和获取模型超参数以及其他重要属性的方式,以便在训练和使用模型时能够调整和控制模型的行为。

 def compute_null_mem(self, batch_size):
        # we initialize the null memory slot with an average of stop words
        #    by supposing that the model could learn to ignore these words
        emb_null = self.layers['word_embed'](self.null_idxes)

        # (1, L, 2*H)
        enc_outs, _ = self.layers['encoder'](emb_null)

        # (1, L, 2 * H) -> (1, L, D)
        null_mem = self.layers['mem_compress'](enc_outs)

        # (1, L, D)->(1, 1, D)->(B, 1, D)
        self.null_mem = null_mem.mean(dim=1, keepdim=True).repeat(batch_size, 1, 1)

这段代码是用于初始化“空白内存”槽的函数 'compute_null_mem'。让我来解释它的主要步骤:

1. 'emb_null = self.layers ['word_embed'](self.null_idxes)':首先,通过模型的 'word_embed' 组件,将名为 'self.null_idxes' 的张量(包含空白索引)进行词嵌入操作。这将把空白索引转换为嵌入的词向量。

2. 'enc_outs, _ = self.layers ['encoder'](emb_null)':接下来,将嵌入的词向量输入到模型的编码器(encoder)中,'enc_outs' 会包含编码的信息,但在这段代码中不会被使用,因此用下划线 '_' 表示不需要返回的值。

3. 'null_mem = self.layers['mem_compress'](enc_outs)':然后,对编码器的输出进行进一步的处理,通过模型的 'mem_compress' 组件,将其维度从(1,L,2*H)压缩到(1,L,D),其中 'D' 是内存槽的维度。

4. 'self.null_mem = null_mem.mean(dim=1, keepdim=True).repeat(batch_size, 1, 1)':最后,计算 'null_mem' 的均值,将其维度从(1,L,D)变为(1,1,D),然后通过 '.repeat' 操作,将其复制为 'batch_size' 份,以初始化每个批次的“空白内存”槽。

这个过程旨在通过对"空白索引"进行嵌入和处理,为模型提供一个初始的内存槽,这个内存槽中包含了一种表示"模型应该忽略的词"的信息。这有助于模型学习在生成文本时忽略不必要的词语。

 def computer_topic_memory(self, keys):
        # (B, key_len)
        emb_keys = [self.layers['word_embed'](key) for key in keys]
        key_lens = [get_seq_length(key, self.pad_idx, self.device) for key in keys]

        batch_size = emb_keys[0].size(0)

        # length == 0 means this is am empty topic slot
        topic_mask = torch.zeros(batch_size, self.topic_slots,
            dtype=torch.float, device=self.device).bool() # (B, topic_slots)
        for step in range(0, self.topic_slots):
            topic_mask[:, step] = torch.eq(key_lens[step], 0)


        key_states_vec, topic_slots = [], []
        for step, (emb_key, length) in enumerate(zip(emb_keys, key_lens)):

            # we set the length of empty keys to 1 for parallel processing,
            #   which will be masked then for memory reading
            length.masked_fill_(length.eq(0), 1)

            _, state = self.layers['encoder'](emb_key, length)
            # (2, B, H) -> (B, 2, H) -> (B, 2*H)
            key_state = state.transpose(0, 1).contiguous().view(batch_size, -1)
            mask = (1 - topic_mask[:, step].float()).unsqueeze(1) # (B, 1)

            key_states_vec.append((key_state*mask).unsqueeze(1))

            topic = self.layers['mem_compress'](key_state)
            topic_slots.append((topic*mask).unsqueeze(1))

        # (B, topic_slots, mem_size)
        topic_mem = torch.cat(topic_slots, dim=1)

        # (B, H)
        key_init_state = self.layers['key_init'](
            torch.cat(key_states_vec, dim=1).sum(1))

        return topic_mem, topic_mask, key_init_state

这段代码是用于计算主题记忆的函数 'computer_topic_memory'。以下是它的主要步骤:

1. 'emb_keys = [self.layers['word_embed'](key) for key in keys]':首先,将给定的主题文本索引 'keys' 转换为词嵌入向量列表 'emb_keys',使用模型的 'word_embed' 组件进行词嵌入。

2. 'key_lens = [get_seq_length(key, self.pad_idx, self.device) for key in keys]':计算每个主题文本的长度,主要用于后续处理。

3. 创建一个名为 'topic_mask' 的布尔型张量,其维度是(B,topic_slots),其中 'B' 是批次大小,'topic_slots' 是主题内存槽的数量。

4. 遍历主题内存槽,并根据每个主题文本的长度是否为0,设置相应的 'topic_mask' 值。如果长度为0,说明这是一个空的主题槽,对应位置的 'topic_mask' 设置为True,否则为False。

5. 初始化两个列表 'key_states_vec' 和 'topic_slots',用于存储主题内存槽的相关信息。

6. 遍历每个主题文本,对于每个主题文本,进行以下操作:

- 通过编码器 'encoder' 处理嵌入的主题文本 'emb_key',并使用 'length' 来指示文本的长度。

- 从编码器输出中提取状态信息 'state'。

- 调整 'state' 的维度,将其从(2,B,H)转置为(B,2*H),然后使用 'concontinuguous' 确保内存连续。

- 创建一个用于屏蔽(mask)的张量,其中非空的主题槽位置对应的mask值为0(False),表示要考虑该槽的信息。

- 将处理后的 'key_state' 和 'topic' 分别添加到 'key_states_vec' 和 'topic_slots' 中。

7. 将所有主题槽的信息拼接在一起,形成主题内存 'topic_mem',其维度为(B,topic_slots,mem_size)。

8. 计算主题内存的初始状态 'key_init_state',通过对 'key_states_vec' 中的所有信息求和,然后将其输入到 'key_init' 组件中。

最终,该函数返回计算得到的主题内存 'topic_mem',主题内存槽的遮罩 'topic_mask',以及主题内存的初始状态 'key_init_state'。这些信息将在模型中用于生成与主题相关的

 def computer_local_memory(self, inps, with_length):
        batch_size = inps.size(0)
        if with_length:
            length = get_seq_length(inps, self.pad_idx, self.device)
        else:
            length = None

        emb_inps = self.layers['word_embed'](inps)

        # outs: (B, L, 2 * H)
        # states: (2, B, H)
        enc_outs, enc_states = self.layers['encoder'](emb_inps, length)

        init_state = self.layers['dec_init'](enc_states.transpose(0, 1).
            contiguous().view(batch_size, -1))

        # (B, L, 2 * H) -> (B, L, D)
        local_mem = self.layers['mem_compress'](enc_outs)

        local_mask = torch.eq(inps, self.pad_idx)

        return local_mem, local_mask, init_state

这段代码是用于计算局部记忆(local memory)的函数 'computer_local_memory'。以下是它的主要步骤:

1. 'batch_size = inps.size(0)':获取输入文本的批次大小 'batch_size'。

2. 'if with_length:':检查是否需要考虑输入文本的长度信息。如果 'with_length' 为 True,表示需要考虑长度信息。

3. 如果需要考虑长度信息,那么通过 'get_seq_length' 函数计算输入文本 'inps' 中的每个句子的长度,并将结果存储在 'length' 变量中。否则,将 'length' 设置为 'None',表示不考虑长度信息。

4. 通过词嵌入组件 'word_embed',将输入文本 'inps' 转换为嵌入向量 'emb_inps'。这里是将文本中的词汇索引转换为词嵌入向量。

5. 使用编码器 'encoder' 处理嵌入的输入文本 'emb_inps',同时传入文本的长度信息 'length'。编码器会生成两个输出:'enc_outs' 表示编码器的输出,维度为(B,L,2 * H),其中 B 是批次大小,L 是句子长度,2 * H 表示隐藏状态的维度。'enc_states' 表示编码器的最终状态,维度为(2,B,H),其中 2 表示双向编码器的状态。

6. 通过初始化状态的组件 'dec_init',初始化初始状态 'init_state'。首先,将 'enc_states' 进行转置操作,从(2,B,H)转置为(B,2,H),然后使用 'contiguous' 以确保内存连续。接着,将其展平为维度为(batch_size,-1)的张量。

7. 使用 'mem_compress' 组件,将编码器的输出 'enc_outs' 转换为本地内存 'local_mem',其维度为(B,L,D),其中 D 表示本地内存的维度。这一步将编码器输出的隐藏状态进行压缩。

8. 创建局部记忆的掩盖 'local_mask',用于标记输入文本中的填充位置(pad_idx)。'local_mask' 的维度与输入文本相同,对于填充位置的元素为True,其他位置为False。

最终,该函数返回计算得到的局部记忆 'local_mem'、局部记忆掩盖 'local_mask',以及初始状态 'init_state'。局部记忆将在后续模型中用于生成与输入文本相关的文本。

    def update_global_trace(self, old_global_trace, dec_states, dec_mask):
        states = torch.cat(dec_states, dim=2) # (B, H, L)
        global_trace = self.layers['global_trace_updater'](
            old_global_trace, states*(dec_mask.unsqueeze(1)))
        return global_trace


    def update_topic_trace(self, topic_trace, topic_mem, concat_aligns):
        # topic_trace: (B, topic_trace_size+topic_slots)
        # concat_aligns: (B, L_gen, mem_slots)

        # 1: topic memory, 2: history memory 3: local memory
        topic_align = concat_aligns[:, :, 0:self.topic_slots].mean(dim=1) # (B, topic_slots)

        # (B, topic_slots, mem_size) * (B, topic_slots, 1) -> (B, topic_slots, mem_size)
        #   -> (B, mem_size)
        topic_used = torch.mul(topic_mem, topic_align.unsqueeze(2)).mean(dim=1)


        new_topic_trace = self.layers['topic_trace_updater'](
            torch.cat([topic_trace[:, 0:self.topic_trace_size], topic_used], dim=1))

        read_log = topic_trace[:, self.topic_trace_size:] + topic_align

        fin_topic_trace = torch.cat([new_topic_trace, read_log], dim=1)

        return fin_topic_trace

这段代码包括两个函数:update_global_traceupdate_topic_trace,它们用于更新全局追踪(global trace)和主题追踪(topic trace)。

update_global_trace函数接受以下参数:

  • old_global_trace:先前的全局追踪状态,维度为(B,global_trace_size)。
  • dec_states
  • dec_mask:解码器的掩码,维度为(B,L)。

update_global_trace函数的操作如下:

  1. 首先,使用 'torch.cat
  2. 接着,通过 组件来更新全局追踪状态,传入旧的全局追踪状态global_trace_updaterold_global_trace和解码器隐藏状态states,但只考虑有效时间步的部分,因此乘以dec_mask进行掩码操作。
  3. 最终,函数返回更新后的全局追踪状态global_trace

update_topic_trace函数接受以下参数:

  • topic_trace:当前的主题追踪状态,维度为(B,topic_trace_size+topic_slots)。
  • topic_mem
  • 'concat_aligns

函数的操作如下:

  1. 首先,计算主题对齐(topic_align),从 'concat_aligns
  2. 接着,将主题内存topic_mem与主题对齐topic_align相乘,再取平均,得到topic_used,表示主题内存中被使用的部分。topic_used的维度为(B,mem_size)。
  3. 使用 'topic_trace_updater
  4. 计算主题追踪的读取日志(read_log),将它与主题追踪中已有的部分相加,得到 '
  5. 最终,函数返回更新后的主题追踪状态 '

这两个函数在模型中用于更新全局追踪和主题追踪,这些状态用于生成与输入文本相关的文本内容。

    def dec_step(self, inp, state, ph, length, total_mem, total_mask,
        global_trace, topic_trace):

        emb_inp = self.layers['word_embed'](inp)
        emb_ph = self.layers['ph_embed'](ph)
        emb_len = self.layers['len_embed'](length)

        # query for reading read memory
        # (B, 1, H]
        query = torch.cat([state, global_trace, topic_trace], dim=1).unsqueeze(1)

        # attns: (B, 1, mem_size), align: (B, 1, L)
        attns, align = self.layers['memory_read'](query, total_mem, total_mem, total_mask)


        x = torch.cat([emb_inp, emb_ph, emb_len, attns, global_trace], dim=1).unsqueeze(1)
        x = self.layers['merge_x'](x)

        cell_out, new_state = self.layers['decoder'](x, state)
        out = self.layers['out_proj'](cell_out)
        return out, new_state, align


    def run_decoder(self, inps, trgs, phs, lens, key_init_state,
        history_mem, history_mask, topic_mem, topic_mask, global_trace, topic_trace,
        specified_teach_ratio):

        local_mem, local_mask, init_state = \
            self.computer_local_memory(inps, key_init_state is None)

        if key_init_state is not None:
            init_state = key_init_state

        if specified_teach_ratio is None:
            teach_ratio = self._teach_ratio
        else:
            teach_ratio = specified_teach_ratio


        # Note this order: 1: topic memory, 2: history memory 3: local memory
        total_mask = torch.cat([topic_mask, history_mask, local_mask], dim=1)
        total_mem = torch.cat([topic_mem, history_mem, local_mem], dim=1)

        batch_size = inps.size(0)
        trg_len = trgs.size(1)

        outs = torch.zeros(batch_size, trg_len, self.vocab_size,
            dtype=torch.float, device=self.device)

        state = init_state
        inp = self.bos_tensor.repeat(batch_size)
        dec_states, attn_weights = [], []

        # generate each line
        for t in range(0, trg_len):
            out, state, align = self.dec_step(inp, state, phs[:, t],
                lens[:, t], total_mem, total_mask, global_trace, topic_trace)
            outs[:, t, :] = out

            attn_weights.append(align)

            # teach force with a probability
            is_teach = random.random() < teach_ratio
            if is_teach or (not self.training):
                inp = trgs[:, t]
            else:
                normed_out = F.softmax(out, dim=-1)
                inp = normed_out.data.max(1)[1]

            dec_states.append(state.unsqueeze(2)) # (B, H, 1)
            attn_weights.append(align)



        # write the history memory
        if key_init_state is None:
            new_history_mem, _ = self.layers['memory_write'](history_mem, local_mem,
                1.0-local_mask.float(), global_trace, self.null_mem)
        else:
            new_history_mem = history_mem

        # (B, L)
        dec_mask = get_non_pad_mask(trgs, self.pad_idx, self.device)

        # update global trace vector
        new_global_trace = self.update_global_trace(global_trace, dec_states, dec_mask)


        # update topic trace vector
        # attn_weights: (B, 1, all_mem_slots) * L_gen
        concat_aligns = torch.cat(attn_weights, dim=1)
        new_topic_trace = self.update_topic_trace(topic_trace, topic_mem, concat_aligns)


        return outs, new_history_mem, new_global_trace, new_topic_trace

dec_step这个函数接受以下参数:

  • inp:输入的当前时间步的单词,维度为(B,1)。
  • state:解码器当前时间步的隐藏状态,维度为(B,H)。
  • ph:与当前时间步相关的某种词性(词性)信息,维度为(B,1)。
  • length:与当前时间步相关的句子长度信息,维度为(B,1)。
  • total_mem
  • total_mask
  • '全球
  • topic_trace:主题追踪,维度为(B,topic_trace_size+topic_slots)。

函数的操作如下:

  1. 使用词嵌入层分别嵌入当前时间步的输入单词inp、词性ph和句子长度length
  2. 创建查询向量query,将解码器当前隐藏状态state、全局追踪global_trace和主题追踪topic_trace拼接在一起,然后添加一个额外的维度,得到维度为(B,1,H)的查询向量。
  3. 使用 'memory
  4. 将各种嵌入向量、内存读取结果和全局追踪拼接在一起,构成输入x,然后传入merge_x组件进行处理。
  5. 将输入 和当前隐藏状态xstate传入解码器('decoderdecoder),得到单步的细胞输出cell_out和新的隐藏状态new_state
  6. 最后,将细胞输出cell_out传入输出层out_proj,得到当前时间步的输出out,同时返回新的隐藏状态new_state和对齐信息align

run_decoder这个函数执行整个解码过程,从输入数据 inps开始,生成整个句子的输出。函数的操作如下:

  1. 初始化总内存(总内存包括主题内存、历史内存和本地内存)和总掩码,将它们分别组合为 和 。total_memtotal_mask
  2. 创建输出张量 用于存储解码器的输出,维度为(B,trg_len,vocab_size),其中 B 是批次大小, 是目标句子的长度, 是outstrg_lenvocab_size
  3. 初始化当前时间步的输入单词inp和隐藏状态state
  4. 使用循环遍历每个时间步,调用 'dec_step
  5. 更新历史内存,根据是否有初始的关键内存 'key_init
  6. 获取解码器输出的掩码dec_mask
  7. 更新全局追踪new_global_trace,将其传入update_global_trace函数。
  8. 更新主题追踪 'new_topic_trace
  9. 返回模型的输出 、新的历史内存 、新的全局追踪outsnew_history_memnew_global_trace和新的主题追踪 'newnew_topic_trace
def initialize_mems(self, keys):
    batch_size = keys[0].size(0)
    topic_mem, topic_mask, key_init_state = self.computer_topic_memory(keys)

    history_mem = torch.zeros(batch_size, self.his_mem_slots, self.mem_size,
        dtype=torch.float, device=self.device)

    # default: True, masked
    history_mask = torch.ones(batch_size, self.his_mem_slots,
        dtype=torch.float, device=self.device).bool()

    global_trace = torch.zeros(batch_size, self.global_trace_size,
        dtype=torch.float, device=self.device)
    topic_trace = torch.zeros(batch_size, self.topic_trace_size+self.topic_slots,
        dtype=torch.float, device=self.device)

    self.compute_null_mem(batch_size)

    return topic_mem, topic_mask, history_mem, history_mask, global_trace, topic_trace, key_init_state

这段代码是一个初始化内存的函数initialize_mems,用于为模型的各个内存和追踪创建初始值。以下是函数的关键部分:这个函数接受一个名为 'keys

  1. 计算主题内存topic_mem、主题掩码 和关键内存的初始状态 'key_inittopic_maskkey_init_state,通过调用computer_topic_memory函数。
  2. 初始化历史内存history_mem,其维度为(batch_size,his_mem_slots,mem_size),并全部填充为零。
  3. 初始化历史内存的掩码 ,默认情况下全部为 True,即已掩盖。history_mask
  4. 初始化全局追踪 '
  5. 初始化主题追踪topic_trace,其维度为(batch_size,topic_trace_size+topic_slots),并全部填充为零。
  6. 调用 函数来计算并初始化空白内存(null memory)。compute_null_mem
  7. 返回初始化后的内存和追踪对象,包括主题内存topic_mem、主题掩码topic_mask、历史内存 history_mem、历史掩码history_mask、全局追踪global_trace、主题追踪 topic_trace和关键内存初始状态key_init_state。这些对象将在模型的解码过程中使用。
    def rebuild_inps(self, ori_inps, last_outs, teach_ratio):
        # ori_inps: (B, L)
        # last_outs: (B, L, V)
        inp_len = ori_inps.size(1)
        new_inps = torch.ones_like(ori_inps) * self.pad_idx

        mask = get_non_pad_mask(ori_inps, self.pad_idx, self.device).long()

        if teach_ratio is None:
            teach_ratio = self._teach_ratio

        for t in range(0, inp_len):
            is_teach = random.random() < teach_ratio
            if is_teach or (not self.training):
                new_inps[:, t] = ori_inps[:, t]
            else:
                normed_out = F.softmax(last_outs[:, t], dim=-1)
                new_inps[:, t] = normed_out.data.max(1)[1]

        new_inps = new_inps * mask

        return new_inps

这个函数用于重新构建输入序列,根据教师强制比率 '

  1. 获取原始输入序列 '
  2. 创建一个新的输入序列 'new
  3. 创建一个掩码 '
  4. 对于输入序列的每个时间步 't
  5. 最后,将新的输入序列 '

这个函数的目的是在训练期间执行教师强制,以指导模型生成输出。如果teach_ratio接近1.0,那么几乎每个时间步都将使用教师强制;如果teach_ratio接近0.0,那么几乎每个时间步都将根据模型的输出来生成输入。这种灵活性可以用来平衡模型的自由生成和保守生成,具体取决于所选择的教师强制比率。

    def forward(self, all_inps, all_trgs, all_ph_inps, all_len_inps, keys, teach_ratio=None,
        flexible_inps=False):
        '''
        all_inps: (B, L) * sens_num
        all_trgs: (B, L) * sens_num
        all_ph_inps: (B, L) * sens_num
        all_len_inps: (B, L) * sens_num
        keys: (B, L) * topic_slots
        flexible_inps: if apply partial teaching force to local memory.
            False: the ground-truth src line is stored into the local memory
            True: for local memory, ground-truth characters will be replaced with generated characters with
                the probability of 1- teach_ratio.
            NOTE: this trick is *not* adopted in our original paper, which could lead to
                better BLEU and topic relevance, but worse diversity of generated poems.
        '''
        all_outs = []

        topic_mem, topic_mask, history_mem, history_mask,\
            global_trace, topic_trace, key_init_state = self.initialize_mems(keys)

        for step in range(0, self.sens_num):
            if step > 0:
                key_init_state = None

            if step >= 1 and flexible_inps:
                inps = self.rebuild_inps(all_inps[step], all_outs[-1], teach_ratio)
            else:
                inps = all_inps[step]

            outs, history_mem, global_trace, topic_trace \
                = self.run_decoder(inps, all_trgs[step],
                    all_ph_inps[step], all_len_inps[step], key_init_state,
                    history_mem, history_mask, topic_mem, topic_mask,
                    global_trace, topic_trace, teach_ratio)

            if step >= 1:
                history_mask = history_mem.abs().sum(-1).eq(0) # (B, mem_slots)


            all_outs.append(outs)


        return all_outs

这个'forward'函数是该模型的主要前向传播方法,用于执行文本生成任务。以下是函数的主要步骤和参数说明:

  1. - 'all_inps': 输入序列,一个列表,每个元素都是形状为'(B, L)'的输入数据,其中'B'是批量大小,'L'是序列长度。'sens_num' 个输入序列按顺序排列,以处理多个敏感主题的文本生成。 - 'all_trgs': 目标输出序列,与 'all_inps' 具有相同的结构。
  2. - 'all_ph_inps': 与 'all_inps' 结构相同的辅助输入序列,用于指定生成的内容或标记。
  3. - 'all_len_inps': 与 'all_inps' 结构相同的输入序列,其中包含有关输入序列长度的信息。
  4. - 'keys': 与 'all_inps' 结构相同的输入序列,包含用于主题相关信息的关键。
  5. - 'teach_ratio': 教师强制比率,用于平衡生成模式。如果不提供,默认为 'None'。
  6. - 'flexible_inps': 一个布尔值,用于确定是否应用部分教师强制到本地内存中的输入序列。如果设置为 'True',将使用模型生成的字符替换本地内存中的地面真相字符,以概率 '1 - teach_ratio'。这是一个用于平衡生成多样性和准确性的技巧。

函数的主要执行步骤包括:

  1. 1. 初始化记忆:调用 'initialize_mems' 方法,初始化主题记忆、历史记忆和跟踪记忆。
  2. 2. 遍历 'sens_num' 次,每次处理一个输入序列。对于每个步骤:
  3. - 如果 'step > 0',将 'key_init_state' 设置为 'None',这将指示模型不再使用初始键状态。
  4. - 如果 'flexible_inps' 为 'True' 且 'step >= 1',则调用 'rebuild_inps' 方法,根据之前的输出和 'teach_ratio' 重建输入序列。否则,直接使用给定的输入序列。
  5. - 调用 'run_decoder' 方法,执行解码过程,生成输出序列。这将返回生成的输出、更新的历史内存、全局跟踪向量和主题跟踪向量。
  6. - 如果 'step >= 1',更新历史内存的掩码,用于标记不再需要的内存槽。
  7. 3. 返回 'all_outs',它是一个列表,包含每个输入序列的生成输出序列。

这个函数允许模型在多个输入序列上执行文本生成任务,支持教师强制比率的设置,并且可以选择是否使用部分教师强制来改善生成的多样性和相关性。

    def dseq_graph(self, inps, trgs, ph_inps, len_inps, teach_ratio=None):
        # pre-train the encoder and decoder as a denoising Seq2Seq model
        batch_size, trg_len = trgs.size(0), trgs.size(1)
        length = get_seq_length(inps, self.pad_idx, self.device)


        emb_inps = self.layers['word_embed'](inps)
        emb_phs = self.layers['ph_embed'](ph_inps)
        emb_lens = self.layers['len_embed'](len_inps)


        # outs: (B, L, 2 * H)
        # states: (2, B, H)
        _, enc_states = self.layers['encoder'](emb_inps, length)


        init_state = self.layers['dec_init_pre'](enc_states.transpose(0, 1).
            contiguous().view(batch_size, -1))


        outs = torch.zeros(batch_size, trg_len, self.vocab_size,
            dtype=torch.float, device=self.device)

        if teach_ratio is None:
            teach_ratio = self._teach_ratio

        state = init_state
        inp = self.bos_tensor.repeat(batch_size, 1)

        # generate each line
        for t in range(0, trg_len):
            emb_inp = self.layers['word_embed'](inp)
            x = self.layers['merge_x_pre'](torch.cat(
                [emb_inp, emb_phs[:, t].unsqueeze(1), emb_lens[:, t].unsqueeze(1)],
                dim=-1))

            cell_out, state, = self.layers['decoder'](x, state)
            out = self.layers['out_proj'](cell_out)

            outs[:, t, :] = out

            # teach force with a probability
            is_teach = random.random() < teach_ratio
            if is_teach or (not self.training):
                inp = trgs[:, t].unsqueeze(1)
            else:
                normed_out = F.softmax(out, dim=-1)
                top1 = normed_out.data.max(1)[1]
                inp  = top1.unsqueeze(1)


        return outs

这个'dseq_graph'函数用于在预训练阶段训练编码器和解码器,它实际上是一个自动编码器(Denoising Seq2Seq)模型。以下是函数的主要步骤和参数说明:

- 'inps': 输入序列,形状为'(B, L)',其中'B'是批量大小,'L'是输入序列的长度。

- 'trgs': 目标输出序列,与 'INPS' 具有相同的结构。

- 'ph_inps': 与 'inps' 结构相同的辅助输入序列,用于指定生成的内容或标记。

- 'len_inps': 与 'inps' 结构相同的输入序列,其中包含有关输入序列长度的信息。

- 'teach_ratio': 教师强制比率,用于平衡生成模式。如果不提供,默认为 'None'。

函数的主要执行步骤包括:

  1.  获取输入序列的长度信息,用于生成器的解码过程。
  2.  对输入序列、辅助输入序列和长度输入序列进行嵌入,以获取输入的词嵌入。
  3.  使用编码器('encoder')处理输入序列,生成编码器的输出和状态。这个编码器可以看作是一个特征提取器,将输入序列编码成一个表示。
  4.  初始化解码器的初始状态,这将作为解码过程的起点。
  5.  初始化一个输出张量 'outs',用于存储模型生成的输出。
  6.  遍历目标输出序列的每个时间步,执行解码过程。在每个时间步:  

- 获取当前时间步的目标输出标记 'trgs[:, t]'。

- 生成当前时间步的输入词嵌入 'emb_inp'。

- 构建解码器的输入张量 'x',该张量是当前时间步的输入词嵌入、辅助输入序列中的内容和长度输入序列中的信息的组合。

- 通过解码器('decoder')处理输入张量 'x' 和当前状态,获得解码过程的输出。

- 应用输出投影层('out_proj')获得最终的输出概率分布。 - 将当前时间步的输出添加到 'outs' 张量中。

- 使用教师强制比率('teach_ratio')来决定是否应用教师强制。如果是教师强制,将当前时间步的目标输出作为下一个时间步的输入;否则,将模型生成的概率分布中概率最高的标记作为下一个时间步的输入。

7. 返回输出张量 'outs',其中包含了模型生成的序列。

这个函数的目的是在预训练阶段将编码器和解码器一起训练,以捕获输入序列的信息并生成目标输出序列。

    def dseq_parameter_names(self):
        required_names = ['word_embed', 'ph_embed', 'len_embed',
            'encoder', 'decoder', 'out_proj',
            'dec_init_pre', 'merge_x_pre']
        return required_names

    def dseq_parameters(self):
        names = self.dseq_parameter_names()

        required_params = [self.layers[name].parameters() for name in names]

        return chain.from_iterable(required_params)

dseq_parameter_names函数返回了在 函数中使用的子模型的名称列表。这些子模型包括:dseq_graph

  • 'word_embed': 用于将输入序列中的词语映射为词嵌入的模型。
  • 'ph_embed': 用于将辅助输入序列中的内容映射为词嵌入的模型。
  • 'len_embed': 用于将长度输入序列中的信息映射为词嵌入的模型。
  • 'encoder': 编码器模型,用于将输入序列编码为隐藏表示。
  • 'decoder': 解码器模型,用于生成目标输出序列。
  • 'out_proj': 输出投影层,将解码器的输出映射到词汇表上的分布。
  • 'dec_init_pre': 解码器初始状态的初始化模型。
  • 'merge_x_pre': 用于将输入的不同部分合并为解码器的输入张量的模型。

dseq_parameters函数用于获取这些子模型的参数。它首先获取这些子模型的名称列表,然后对每个子模型,通过self.layers[name].parameters() 获取其参数。最后, chain.from_iterable(required_params)将这些参数连接成一个迭代器,以便进行优化和训练。这可以用于在模型训练时获取需要更新的参数。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值