PLUG AND PLAY LANGUAGE MODELS: A SIMPLE APPROACH TO CONTROLLED TEXT GENERATION (PPLM): Code Deep Dive (Part 2): PPLM_Discrim

Code link: PPLM_code

2. Example command for discriminator-based sentiment control

python run_pplm.py -D sentiment --class_label 2 --cond_text "My dog died" --length 50 --gamma 1.0 --num_iterations 10 --num_samples 10 --stepsize 0.04 --kl_scale 0.01 --gm_scale 0.95 --sample

Here --class_label 2 selects the "very_positive" class of the SST sentiment discriminator (see DISCRIMINATOR_MODELS_PARAMS in the code below), so the perturbation steers the continuation of "My dog died" toward positive sentiment.

Code logic of PPLM_Discrim:

1. run_pplm_example(): the entry point, which parses the arguments above

2. Load the model and tokenizer; freeze the model parameters (sketch below)
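A minimal sketch of this setup step, following run_pplm_example in the repo (it assumes the older transformers version the repo pins, hence the transformers.modeling_gpt2 import path):

import torch
from transformers import GPT2Tokenizer
from transformers.modeling_gpt2 import GPT2LMHeadModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained LM with hidden states exposed (PPLM needs them for
# the discriminator loss) and switch to eval mode.
model = GPT2LMHeadModel.from_pretrained("gpt2-medium", output_hidden_states=True)
model.to(device)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")

# Freeze all weights: PPLM never updates the LM itself, it only perturbs
# the cached key/value activations (the "past").
for param in model.parameters():
    param.requires_grad = False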

3. full_text_generation(): returns the generated results, i.e. the unperturbed sentence and the perturbed sentences;

3.1 generate_text_pplm (unperturbed): generates the text without any perturbation (sketched below);
3.1.1 for i in range_func: one token is generated per iteration (sentence length set to 30 in the author's run)
  1. If the input has, say, 3 tokens, this step feeds only the first 2 and returns hidden, priming the key/value cache

  2. No perturbation, so pert_past = past

  3. Pass last and pert_past into the model; get back pert_logits, pert_past, pert_all_hidden; past is updated to include the current last token

  4. Take the top-k tokens according to pert_logits

  5. Take the most probable token as the next last and append it to output_so_far
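A minimal sketch of this unperturbed path (the helper name generate_unperturbed is mine, not the repo's; it assumes the model/tokenizer from the setup sketch above and the pre-1.0 transformers API, where a forward call returns (logits, past, all_hidden)):

import torch

def generate_unperturbed(model, tokenizer, cond_text, length=30, device="cuda"):
    # Hypothetical helper mirroring steps 1-5 above; the repo's
    # generate_text_pplm handles both this path and the perturbed one.
    output_so_far = torch.tensor(
        [tokenizer.encode(tokenizer.bos_token + cond_text)], device=device)
    past = None
    for _ in range(length):
        last = output_so_far[:, -1:]
        # Step 1: prime the key/value cache on all tokens except the last,
        # so `last` can then be fed on its own.
        if past is None and output_so_far.shape[1] > 1:
            _, past, _ = model(output_so_far[:, :-1])
        # Steps 2-3: no perturbation, so pert_past is simply past; the
        # forward call also returns the updated past including `last`.
        pert_logits, past, _ = model(last, past=past)
        # Steps 4-5: keep the most probable token and append it.
        next_token = torch.argmax(pert_logits[:, -1, :], dim=-1, keepdim=True)
        output_so_far = torch.cat((output_so_far, next_token), dim=1)
    return output_so_far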

3.2 Sampling in the same way generates the perturbed sentences (5 in the author's run):

3.2.1 generate_text_pplm (with perturbation)
3.2.1.1 for i in range_func: one token is generated per iteration
  1. If the input has 3 tokens, this step feeds only the first 2 and returns hidden
  2. Feed all three tokens; get the unperturbed unpert_logits, unpert_past, unpert_all_hidden
  3. accumulated_hidden = unpert_last_hidden[:, :-1, :]: only the first k-1 tokens are kept; accumulated_hidden is then summed over the sequence dimension
  4. perturb_past(): returns pert_past (detailed in 4.1-4.3 below; see also the condensed sketch after this walkthrough)
4.1 Initialize grad_accumulator to zeros, shaped like each of the 24 blocks' stacked (k, v) tensors: (2, 1, 16, 2, 64)
4.2 Initialize a window_mask of the same shape (2, 1, 16, 2, 64)
4.3 Run three rounds of gradient iteration:
	4.3.1 Add the gradients accumulated in earlier rounds (grad_accumulator) to past to get perturbed_past; past itself is already detached and takes no gradients
	4.3.2 Feed the most recently generated token last together with perturbed_past into the model; get all_logits and all_hidden
	4.3.3 Take the hidden state produced for last from all_hidden and add it to accumulated_hidden, detached from the graph
	4.3.4 Take the logits and probs the model produced for last; shape (1, 50257)
	++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
	4.3.5 Compute the PPLM_DISCRIM loss
		4.3.5.1 Initialize the cross-entropy loss; take the unpert_past obtained earlier from the unperturbed pass over all the initial tokens
		4.3.5.2 for _ in range(horizon_length):  # horizon_length is 1, i.e. one look-ahead step
		4.3.5.3 Divide new_accumulated_hidden by (curr_length + 1 + horizon_length) and feed it to the classifier for an MLP prediction: (1, 1024) -> (1, 5)
		4.3.5.4 Build the cross-entropy label (the authors trained an MLP as the classifier)
		4.3.5.5 Compute discrim_loss()
	++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
	4.3.6 Compute the KL loss between unpert_probs and probs
	4.3.7 Add the two losses and backpropagate
	4.3.8 Apply window_mask to curr_perturbation (which is in fact grad_accumulator) to compute gradient norms; then compute grad scaled by the step size
	4.3.9 Add the computed grad to grad_accumulator
	4.3.10 Rebuild past as new_past from detached copies of its tensors, so the next iteration starts from a graph-free past
  5. Pass last and pert_past into the model; get pert_logits, pert_past, pert_all_hidden; past is updated with the current last token
  6. Fuse the perturbed and unperturbed probabilities, take the top-k tokens, and rescale to get the new distribution over the vocabulary (see the sketch after this list)
  7. Take the most probable token (or, with --sample, a sampled one) as the next last
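For step 6, a minimal sketch of the fusion, assuming logits already sliced to the last position with shape (1, 50257); fuse_and_sample is my name, top_k_filter is the helper from the listing below, and gm_scale=0.95 / top_k=10 match the example command and the repo default:

import torch
import torch.nn.functional as F

def fuse_and_sample(pert_logits, unpert_logits, gm_scale=0.95, top_k=10,
                    sample=True):
    # Both logits tensors are assumed to have shape (1, vocab_size).
    pert_probs = F.softmax(pert_logits, dim=-1)
    unpert_probs = F.softmax(unpert_logits, dim=-1)
    # Geometric-mean fusion: the perturbed distribution is weighted by
    # gm_scale, the unperturbed one by (1 - gm_scale).
    fused = (pert_probs ** gm_scale) * (unpert_probs ** (1.0 - gm_scale))
    fused = top_k_filter(fused, k=top_k, probs=True)  # zero non-top-k entries
    fused = fused / fused.sum()  # rescale back to a probability distribution
    if sample:  # --sample: draw the next token from the fused distribution
        return torch.multinomial(fused, num_samples=1)
    return torch.argmax(fused, dim=-1, keepdim=True)  # greedy otherwise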

4. Print the unperturbed sentence and the perturbed sentences.
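To tie steps 4.1-4.3.9 together, here is a condensed sketch of the perturb_past loop for the discriminator case. The function name perturb_past_sketch is mine; the window_mask, decay, and horizon_length look-ahead forward pass are omitted for brevity, and it assumes the pre-1.0 transformers API where model(last, past=...) returns (logits, past, all_hidden). The classifier argument stands in for the repo's ClassificationHead, which is a single torch.nn.Linear(embed_size, class_size):

import numpy as np
import torch
import torch.nn.functional as F
from operator import add

def perturb_past_sketch(past, model, last, unpert_probs, accumulated_hidden,
                        classifier, class_label, num_iterations=3,
                        stepsize=0.04, kl_scale=0.01, gamma=1.0,
                        horizon_length=1, device="cuda"):
    # 4.1: one zero array per layer, shaped like the stacked (key, value)
    # pair, e.g. (2, 1, 16, curr_length, 64) for GPT-2 medium.
    grad_accumulator = [np.zeros(p.shape).astype("float32") for p in past]
    curr_length = past[0].shape[-2]
    ce_loss = torch.nn.CrossEntropyLoss()

    for _ in range(num_iterations):  # 4.3: three gradient iterations
        curr_perturbation = [
            torch.from_numpy(p_).to(device).requires_grad_(True)
            for p_ in grad_accumulator
        ]
        # 4.3.1: perturbed_past = (detached) past + accumulated perturbation
        perturbed_past = list(map(add, past, curr_perturbation))
        # 4.3.2: forward pass through the frozen LM with the perturbed cache
        all_logits, _, all_hidden = model(last, past=perturbed_past)
        # 4.3.3: extend the running (detached) sum of hidden states
        new_accumulated_hidden = accumulated_hidden + \
            torch.sum(all_hidden[-1], dim=1).detach()
        # 4.3.4: next-token distribution under the perturbation, (1, 50257)
        probs = F.softmax(all_logits[:, -1, :], dim=-1)
        # 4.3.5: discriminator loss on the mean hidden state, (1, 1024) -> (1, 5)
        prediction = classifier(
            new_accumulated_hidden / (curr_length + 1 + horizon_length))
        label = torch.tensor([class_label], device=device)
        discrim_loss = ce_loss(prediction, label)
        # 4.3.6: KL between perturbed and unperturbed distributions
        kl_loss = kl_scale * \
            (probs * ((probs + 1e-15) / (unpert_probs + 1e-15)).log()).sum()
        # 4.3.7: combined loss; backward fills curr_perturbation[i].grad
        (discrim_loss + kl_loss).backward()
        # 4.3.8-4.3.9: norm-scaled gradient step added into the accumulator
        for i, p_ in enumerate(curr_perturbation):
            norm = p_.grad.norm() + 1e-15
            grad_accumulator[i] += (
                -stepsize * (p_.grad / norm ** gamma)).cpu().numpy()

    # Final perturbed past handed back to the generation loop.
    return [p + torch.from_numpy(g).to(device)
            for p, g in zip(past, grad_accumulator)]

Here classifier would be the 5-class SST head loaded via the repo's get_classifier, i.e. effectively a Linear(1024, 5).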

The code is as follows:

#! /usr/bin/env python3
# coding=utf-8
# Copyright 2018 The Uber AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example command with bag of words:
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95

Example command with discriminator:
python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
"""

import argparse
import json
from operator import add
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import trange
from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel

from pplm_classification_head import ClassificationHead

PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
BIG_CONST = 1e10

QUIET = 0
REGULAR = 1
VERBOSE = 2
VERY_VERBOSE = 3
VERBOSITY_LEVELS = {
    'quiet': QUIET,
    'regular': REGULAR,
    'verbose': VERBOSE,
    'very_verbose': VERY_VERBOSE,
}

BAG_OF_WORDS_ARCHIVE_MAP = {
    'legal': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
    'military': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
    'monsters': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/monsters.txt",
    'politics': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
    'positive_words': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/positive_words.txt",
    'religion': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
    'science': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
    'space': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
    'technology': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
}

DISCRIMINATOR_MODELS_PARAMS = {
    "clickbait": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
        "class_size": 2,
        "embed_size": 1024,
        "class_vocab": {"non_clickbait": 0, "clickbait": 1},
        "default_class": 1,
        "pretrained_model": "gpt2-medium",
    },
    "sentiment": {
        # "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
        # local copy used by the author instead of the S3 URL above:
        "path": "/home/xps/huanghong/workdir/PPLM-master/cache/SST_classifier_head.pt",
        "class_size": 5,
        "embed_size": 1024,
        "class_vocab": {"very_positive": 2, "very_negative": 3},
        "default_class": 3,
        "pretrained_model": "gpt2-medium",
    },
}
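# For reference, the repo's get_classifier consumes these params roughly as
# follows (sketch, not the verbatim implementation):
#     params = DISCRIMINATOR_MODELS_PARAMS["sentiment"]
#     classifier = ClassificationHead(class_size=params["class_size"],
#                                     embed_size=params["embed_size"]).to(device)
#     classifier.load_state_dict(torch.load(params["path"], map_location=device))
#     classifier.eval()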


def to_var(x, requires_grad=False, volatile=False, device='cuda'):
    # Wraps a tensor for autograd. Note that Variable and the `volatile`
    # flag are deprecated since PyTorch 0.4; they are kept here only
    # because the original repo uses them.
    if torch.cuda.is_available() and device == 'cuda':
        x = x.cuda()
    elif device != 'cuda':
        x = x.to(device)
    return Variable(x, requires_grad=requires_grad, volatile=volatile)


def top_k_filter(logits, k, probs=False):
    """
    Masks everything but the k top entries as -infinity (1e10).
    Used to mask logits such that e^-infinity -> 0 won't contribute to the
    sum of the denominator.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            return torch.where(logits < batch_mins,
                               torch.ones_like(logits) * 0.0, logits)
        return torch.where(logits < batch_mins,
                           torch.ones_like(logits) * -BIG_CONST,
                           logits)
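# Quick illustration (toy tensor, not from the repo):
#     >>> top_k_filter(torch.tensor([[1.0, 3.0, 2.0, 0.5]]), k=2)
#     tensor([[-1.0000e+10,  3.0000e+00,  2.0000e+00, -1.0000e+10]])
# With probs=True, the entries outside the top k become 0.0 instead of
# -1e10, which is what the fused-distribution path uses.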


def perturb_past(
        past,
        model,
        last,
        unpert_past=None,
        unpert_logits=None,
        accumulated_hidden=None,
        grad_norms=None,
        stepsize=0.01,
        one_hot_bows_vectors=None,
        classifier=None,
        class_label=None,
        loss_type=0,
        num_iterations=3,
        horizon_length=1,
        window_length=0,
        decay=False,
        gamma=1.5,
        kl_scale=0.01,
        device='cuda',
        verbosity_level=REGULAR
):
    # Generate initial perturbed past: one zero tensor per block, with the
    # same shape as that block's stacked (key, value) pair (24 blocks for
    # GPT-2 medium)
    grad_accumulator = [
        (np.zeros(p.shape).astype("float32"))
        for p in past
    ]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    if decay:
        decay_mask = torch.arange(
            0.,
            1.0 + SMALL_CONST,
            1.0 / (window_length)
        )[1:]
    else:
        decay_mask = 1.0

    # Generate a mask so that only the most recent `window_length`
    # positions of the past receive the gradient perturbation
    _, _, _, curr_length, _ = past[0].shape

    if curr_length > window_length and window_length > 0:
        ones_key_val_shape = (
                tuple(past[0].shape[:-2])
                + tuple([window_length])
                + tuple(past[0].shape[-1:])
        )

        zeros_key_val_shape = (
                tuple(past[0].shape[:-2])
                + tuple([curr_length - window_length])
                + tuple(past[0].shape[-1:])
        )

        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat(
            (ones_mask, torch.zeros(zeros_key_val_shape)),
            dim=-2
        ).to(device)
    else:
        window_mask = torch.ones_like(past[0]).to(device)

# (The listing is truncated here; the gradient-iteration loop described in
# 4.3 above follows in the full run_pplm.py at the code link at the top.)